|
@@ -0,0 +1,128 @@
|
|
|
|
|
+"""二次 OCR 分行聚合与择优逻辑单元测试。"""
|
|
|
|
|
+import pytest
|
|
|
|
|
+
|
|
|
|
|
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.text_filling import TextFiller
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestAggregateLineOcr:
|
|
|
|
|
+ def test_drop_low_score_and_length_weighted(self):
|
|
|
|
|
+ # 模拟 cell 207:助款 0.55 应丢弃
|
|
|
|
|
+ blocks = [
|
|
|
|
|
+ ("存折息", 0.855),
|
|
|
|
|
+ ("助设备", 0.953),
|
|
|
|
|
+ ("助款", 0.547),
|
|
|
|
|
+ ]
|
|
|
|
|
+ text, score = TextFiller.aggregate_line_ocr(
|
|
|
|
|
+ blocks, line_min_score=0.6, drop_low_score_blocks=True
|
|
|
|
|
+ )
|
|
|
|
|
+ assert text == "存折息助设备"
|
|
|
|
|
+ expected = (3 * 0.855 + 3 * 0.953) / 6
|
|
|
|
|
+ assert abs(score - expected) < 1e-6
|
|
|
|
|
+
|
|
|
|
|
+ def test_all_dropped_returns_empty(self):
|
|
|
|
|
+ blocks = [("新", 0.54), ("x", 0.5)]
|
|
|
|
|
+ text, score = TextFiller.aggregate_line_ocr(
|
|
|
|
|
+ blocks, line_min_score=0.6, drop_low_score_blocks=True
|
|
|
|
|
+ )
|
|
|
|
|
+ assert text == ""
|
|
|
|
|
+ assert score == 0.0
|
|
|
|
|
+
|
|
|
|
|
+ def test_no_drop_keeps_all(self):
|
|
|
|
|
+ blocks = [("ab", 0.8), ("c", 0.7)]
|
|
|
|
|
+ text, score = TextFiller.aggregate_line_ocr(
|
|
|
|
|
+ blocks, line_min_score=0.6, drop_low_score_blocks=False
|
|
|
|
|
+ )
|
|
|
|
|
+ assert text == "abc"
|
|
|
|
|
+ assert abs(score - (2 * 0.8 + 1 * 0.7) / 3) < 1e-6
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestPickLineVsWhole:
|
|
|
|
|
+ def _filler(self) -> TextFiller:
|
|
|
|
|
+ return TextFiller(ocr_engine=None, config={"second_pass_ocr": {}})
|
|
|
|
|
+
|
|
|
|
|
+ def test_prefer_higher_score_whole(self):
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ t, s, strat = f._pick_line_vs_whole("存折息助设备", 0.85, "存折自助设备取款", 0.92)
|
|
|
|
|
+ assert t == "存折自助设备取款"
|
|
|
|
|
+ assert strat == "whole"
|
|
|
|
|
+
|
|
|
|
|
+ def test_prefer_higher_score_lines(self):
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ t, s, strat = f._pick_line_vs_whole("正确文本", 0.95, "错", 0.5)
|
|
|
|
|
+ assert t == "正确文本"
|
|
|
|
|
+ assert strat == "lines"
|
|
|
|
|
+
|
|
|
|
|
+ def test_tie_prefers_whole(self):
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ t, s, strat = f._pick_line_vs_whole("a", 0.8, "ab", 0.8)
|
|
|
|
|
+ assert t == "ab"
|
|
|
|
|
+ assert strat == "tie_whole"
|
|
|
|
|
+
|
|
|
|
|
+ def test_empty_line_uses_whole(self):
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ t, s, strat = f._pick_line_vs_whole("", 0.0, "整格", 0.7)
|
|
|
|
|
+ assert t == "整格"
|
|
|
|
|
+ assert strat == "whole"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestSanitizeDebugFilename:
|
|
|
|
|
+ def test_illegal_chars(self):
|
|
|
|
|
+ assert TextFiller.sanitize_debug_filename("a/b:c") == "a_b_c"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestSortDetBoxesReadingOrder:
|
|
|
|
|
+ def _box(self, x1, y1, x2, y2):
|
|
|
|
|
+ return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
|
|
|
|
|
+
|
|
|
|
|
+ def test_horizontal_same_row_left_to_right(self):
|
|
|
|
|
+ # 「交易类」在左,「型」在右且 y 略偏(中心点排序易错)
|
|
|
|
|
+ boxes = [
|
|
|
|
|
+ self._box(60, 12, 95, 32), # 型
|
|
|
|
|
+ self._box(5, 10, 55, 30), # 交易类
|
|
|
|
|
+ ]
|
|
|
|
|
+ ordered = TextFiller.sort_det_boxes_reading_order(boxes, 50, 100)
|
|
|
|
|
+ assert ordered[0] is boxes[1]
|
|
|
|
|
+ assert ordered[1] is boxes[0]
|
|
|
|
|
+
|
|
|
|
|
+ def test_vertical_top_to_bottom(self):
|
|
|
|
|
+ boxes = [
|
|
|
|
|
+ self._box(10, 50, 40, 70), # 型
|
|
|
|
|
+ self._box(10, 10, 40, 30), # 交易类
|
|
|
|
|
+ self._box(10, 30, 40, 48), # 中间行
|
|
|
|
|
+ ]
|
|
|
|
|
+ ordered = TextFiller.sort_det_boxes_reading_order(boxes, 80, 50)
|
|
|
|
|
+ assert [boxes.index(b) for b in ordered] == [1, 2, 0]
|
|
|
|
|
+
|
|
|
|
|
+ def test_two_row_table_header(self):
|
|
|
|
|
+ boxes = [
|
|
|
|
|
+ self._box(5, 5, 80, 22),
|
|
|
|
|
+ self._box(5, 28, 30, 45),
|
|
|
|
|
+ ]
|
|
|
|
|
+ ordered = TextFiller.sort_det_boxes_reading_order(boxes, 50, 90)
|
|
|
|
|
+ assert len(ordered) == 2
|
|
|
|
|
+ assert ordered[0] is boxes[0]
|
|
|
|
|
+ assert ordered[1] is boxes[1]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestStripFallbackHeuristic:
|
|
|
|
|
+ def _filler(self) -> TextFiller:
|
|
|
|
|
+ return TextFiller(ocr_engine=None, config={"second_pass_ocr": {}})
|
|
|
|
|
+
|
|
|
|
|
+ def test_needs_strip_when_tall_and_one_block(self):
|
|
|
|
|
+ import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ img = np.zeros((90, 30, 3), dtype=np.uint8)
|
|
|
|
|
+ assert f._needs_strip_line_fallback(img, [("取款", 0.99)]) is True
|
|
|
|
|
+ assert f._needs_strip_line_fallback(img, [("a", 0.9), ("b", 0.9)]) is False
|
|
|
|
|
+
|
|
|
|
|
+ def test_pick_whole_when_much_longer(self):
|
|
|
|
|
+ f = self._filler()
|
|
|
|
|
+ t, s, strat = f._pick_line_vs_whole("取款", 0.99, "存折自助设备取款", 0.85)
|
|
|
|
|
+ assert t == "存折自助设备取款"
|
|
|
|
|
+ assert strat == "whole_longer"
|
|
|
|
|
+
|
|
|
|
|
+ def test_empty_text_zero_score(self):
|
|
|
|
|
+ text, score = TextFiller._parse_single_rec_item(("", 1.0))
|
|
|
|
|
+ assert text == ""
|
|
|
|
|
+ assert score == 0.0
|