Jelajahi Sumber

feat(添加二次OCR聚合与择优逻辑单元测试): 新增针对二次OCR的聚合、择优逻辑及调试功能的单元测试,提升OCR处理的准确性和可维护性。

zhch158_admin 4 hari lalu
induk
melakukan
8e61a877b0

+ 128 - 0
ocr_tools/universal_doc_parser/tests/test_second_pass_ocr_aggregate.py

@@ -0,0 +1,128 @@
+"""二次 OCR 分行聚合与择优逻辑单元测试。"""
+import pytest
+
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.text_filling import TextFiller
+
+
+class TestAggregateLineOcr:  # exercises TextFiller.aggregate_line_ocr on lists of (text, score) blocks
+    def test_drop_low_score_and_length_weighted(self):
+        # Simulates cell 207: the "助款" block (score 0.547, below line_min_score=0.6) should be dropped
+        blocks = [
+            ("存折息", 0.855),
+            ("助设备", 0.953),
+            ("助款", 0.547),
+        ]
+        text, score = TextFiller.aggregate_line_ocr(
+            blocks, line_min_score=0.6, drop_low_score_blocks=True
+        )
+        assert text == "存折息助设备"  # the two kept texts, concatenated in input order
+        expected = (3 * 0.855 + 3 * 0.953) / 6  # mean of the kept scores, weighted by text length (3 chars each)
+        assert abs(score - expected) < 1e-6
+
+    def test_all_dropped_returns_empty(self):  # every block scores below line_min_score
+        blocks = [("新", 0.54), ("x", 0.5)]
+        text, score = TextFiller.aggregate_line_ocr(
+            blocks, line_min_score=0.6, drop_low_score_blocks=True
+        )
+        assert text == ""  # nothing survives, so the text is empty ...
+        assert score == 0.0  # ... and the score is expected to be exactly 0.0
+
+    def test_no_drop_keeps_all(self):  # NOTE(review): both scores already exceed line_min_score, so this does not prove low-score blocks are kept
+        blocks = [("ab", 0.8), ("c", 0.7)]
+        text, score = TextFiller.aggregate_line_ocr(
+            blocks, line_min_score=0.6, drop_low_score_blocks=False
+        )
+        assert text == "abc"
+        assert abs(score - (2 * 0.8 + 1 * 0.7) / 3) < 1e-6  # length-weighted mean over both blocks (2 chars + 1 char)
+
+
+class TestPickLineVsWhole:  # exercises TextFiller._pick_line_vs_whole, which returns a (text, score, strategy) triple
+    def _filler(self) -> TextFiller:  # helper: builds a TextFiller without an OCR engine
+        return TextFiller(ocr_engine=None, config={"second_pass_ocr": {}})  # empty "second_pass_ocr" config section
+
+    def test_prefer_higher_score_whole(self):
+        f = self._filler()
+        t, s, strat = f._pick_line_vs_whole("存折息助设备", 0.85, "存折自助设备取款", 0.92)  # args appear to be (line_text, line_score, whole_text, whole_score) - TODO confirm against TextFiller
+        assert t == "存折自助设备取款"  # the higher-scored second candidate is chosen
+        assert strat == "whole"
+
+    def test_prefer_higher_score_lines(self):
+        f = self._filler()
+        t, s, strat = f._pick_line_vs_whole("正确文本", 0.95, "错", 0.5)
+        assert t == "正确文本"  # the higher-scored first candidate is chosen
+        assert strat == "lines"
+
+    def test_tie_prefers_whole(self):
+        f = self._filler()
+        t, s, strat = f._pick_line_vs_whole("a", 0.8, "ab", 0.8)  # equal scores
+        assert t == "ab"  # on a tie the second candidate is expected
+        assert strat == "tie_whole"
+
+    def test_empty_line_uses_whole(self):
+        f = self._filler()
+        t, s, strat = f._pick_line_vs_whole("", 0.0, "整格", 0.7)  # first candidate is empty with score 0.0
+        assert t == "整格"
+        assert strat == "whole"
+
+
+class TestSanitizeDebugFilename:  # exercises TextFiller.sanitize_debug_filename
+    def test_illegal_chars(self):
+        assert TextFiller.sanitize_debug_filename("a/b:c") == "a_b_c"  # "/" and ":" are each expected to become "_"
+
+
+class TestSortDetBoxesReadingOrder:  # exercises TextFiller.sort_det_boxes_reading_order
+    def _box(self, x1, y1, x2, y2):  # helper: expands two opposite corners into a 4-point box
+        return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]  # corner order: (x1,y1), (x2,y1), (x2,y2), (x1,y2)
+
+    def test_horizontal_same_row_left_to_right(self):
+        # "交易类" is on the left; "型" is on the right with a slight y offset (sorting by center point easily gets this wrong)
+        boxes = [
+            self._box(60, 12, 95, 32),   # box for "型"
+            self._box(5, 10, 55, 30),    # box for "交易类"
+        ]
+        ordered = TextFiller.sort_det_boxes_reading_order(boxes, 50, 100)  # trailing ints are presumably (height, width) of the cell crop - TODO confirm
+        assert ordered[0] is boxes[1]  # identity checks: the same box objects are expected back, left box first
+        assert ordered[1] is boxes[0]
+
+    def test_vertical_top_to_bottom(self):
+        boxes = [
+            self._box(10, 50, 40, 70),   # box for "型"
+            self._box(10, 10, 40, 30),   # box for "交易类"
+            self._box(10, 30, 40, 48),   # middle row
+        ]
+        ordered = TextFiller.sort_det_boxes_reading_order(boxes, 80, 50)
+        assert [boxes.index(b) for b in ordered] == [1, 2, 0]  # increasing y1: 10, 30, 50
+
+    def test_two_row_table_header(self):  # input is already in the expected order; it must come back unchanged
+        boxes = [
+            self._box(5, 5, 80, 22),
+            self._box(5, 28, 30, 45),
+        ]
+        ordered = TextFiller.sort_det_boxes_reading_order(boxes, 50, 90)
+        assert len(ordered) == 2  # no box is lost or duplicated
+        assert ordered[0] is boxes[0]
+        assert ordered[1] is boxes[1]
+
+
+class TestStripFallbackHeuristic:  # covers _needs_strip_line_fallback plus two related edge cases
+    def _filler(self) -> TextFiller:  # helper: builds a TextFiller without an OCR engine
+        return TextFiller(ocr_engine=None, config={"second_pass_ocr": {}})  # empty "second_pass_ocr" config section
+
+    def test_needs_strip_when_tall_and_one_block(self):
+        import numpy as np
+
+        f = self._filler()
+        img = np.zeros((90, 30, 3), dtype=np.uint8)  # all-zero array of shape (90, 30, 3): 90 rows, 30 columns, 3 channels
+        assert f._needs_strip_line_fallback(img, [("取款", 0.99)]) is True  # a single recognized block on this tall image: fallback expected
+        assert f._needs_strip_line_fallback(img, [("a", 0.9), ("b", 0.9)]) is False  # two blocks on the same image: no fallback expected
+
+    def test_pick_whole_when_much_longer(self):
+        f = self._filler()
+        t, s, strat = f._pick_line_vs_whole("取款", 0.99, "存折自助设备取款", 0.85)  # second candidate scores lower but is much longer
+        assert t == "存折自助设备取款"
+        assert strat == "whole_longer"
+
+    def test_empty_text_zero_score(self):
+        text, score = TextFiller._parse_single_rec_item(("", 1.0))  # empty text paired with a score of 1.0
+        assert text == ""
+        assert score == 0.0  # the score is expected to be zeroed when the text is empty