فهرست منبع

feat: Enhance text filling strategy in TextFiller class by introducing overlap ratio calculation for improved OCR box matching, optimizing performance with binary search, and refining text extraction logic.

zhch158_admin 2 روز پیش
والد
کامیت
c628acd7b7
1فایلهای تغییر یافته به همراه120 افزوده شده و 68 حذف شده
  1. 120 68
      ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

+ 120 - 68
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -4,6 +4,7 @@
 提供表格单元格文本填充功能,包括OCR文本匹配和二次OCR填充。
 """
 from typing import List, Dict, Any, Tuple, Optional
+import bisect
 import cv2
 import numpy as np
 from loguru import logger
@@ -26,17 +27,58 @@ class TextFiller:
         self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
         self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.5)
     
+    @staticmethod
+    def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
+        """
+        计算 OCR box 与单元格的重叠比例(重叠面积 / OCR box 面积)
+        
+        这个比例表示 OCR box 有多少部分在单元格内,用于判断 OCR box 是否主要属于该单元格。
+        
+        Args:
+            ocr_bbox: OCR box 坐标 [x1, y1, x2, y2]
+            cell_bbox: 单元格坐标 [x1, y1, x2, y2]
+            
+        Returns:
+            重叠比例 (0.0 ~ 1.0),表示 OCR box 有多少部分在单元格内
+        """
+        if not ocr_bbox or not cell_bbox or len(ocr_bbox) < 4 or len(cell_bbox) < 4:
+            return 0.0
+        
+        # 计算交集
+        inter_x1 = max(ocr_bbox[0], cell_bbox[0])
+        inter_y1 = max(ocr_bbox[1], cell_bbox[1])
+        inter_x2 = min(ocr_bbox[2], cell_bbox[2])
+        inter_y2 = min(ocr_bbox[3], cell_bbox[3])
+        
+        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
+            return 0.0
+        
+        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
+        ocr_area = (ocr_bbox[2] - ocr_bbox[0]) * (ocr_bbox[3] - ocr_bbox[1])
+        
+        if ocr_area <= 0:
+            return 0.0
+        
+        return inter_area / ocr_area
+    
     def fill_text_by_center_point(
         self,
         bboxes: List[List[float]],
         ocr_boxes: List[Dict[str, Any]],
     ) -> Tuple[List[str], List[float], List[List[Dict[str, Any]]], List[int]]:
         """
-        使用中心点落格策略填充文本。
+        使用混合匹配策略填充文本:中心点 + 重叠比例。
+        
+        策略说明:
+        1. 首先用中心点快速筛选:OCR box 的中心点在单元格内
+        2. 然后检查重叠比例:OCR box 与单元格的重叠面积 / OCR box 面积 >= 0.5
+           (这确保 OCR box 主要属于该单元格,避免跨单元格匹配)
+        3. 如果多个单元格都满足条件,选择重叠比例最高的
         
-        参考 fill_html_with_ocr_by_bbox:
-        - OCR文本中心点落入单元格bbox内则匹配
-        - 多行文本按y坐标排序拼接
+        优点:
+        - 比纯 IOU 更宽松,能匹配到更多 OCR box
+        - 比纯中心点更准确,能过滤跨单元格的 OCR box
+        - 适合表格场景,OCR box 通常比单元格小或部分重叠
         
         Args:
             bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
@@ -45,7 +87,7 @@ class TextFiller:
         Returns:
             每个单元格的文本列表
             每个单元格的置信度列表
-            每个单元格匹配到的 OCR boxes 列表
+            每个单元格匹配到的 OCR boxes 列表(已过滤跨单元格的 OCR box)
             需要二次 OCR 的单元格索引列表(OCR box 跨多个单元格或过大)
         """
         texts: List[str] = ["" for _ in bboxes]
@@ -56,61 +98,96 @@ class TextFiller:
         if not ocr_boxes:
             return texts, scores, matched_boxes_list, need_reocr_indices
         
-        # 预处理OCR结果:计算中心点
+        # 预处理OCR结果:转换为 bbox 格式,并计算中心点
         ocr_items: List[Dict[str, Any]] = []
         for item in ocr_boxes:
             # 使用 CoordinateUtils.poly_to_bbox() 替换 _normalize_bbox()
             box = CoordinateUtils.poly_to_bbox(item.get("bbox", []))
-            if not box:
+            if not box or len(box) < 4:
                 continue
             cx = (box[0] + box[2]) / 2
             cy = (box[1] + box[3]) / 2
             ocr_items.append({
+                "bbox": box,
                 "center_x": cx,
                 "center_y": cy,
-                "y1": box[1],
-                "bbox": box,  # 保存 bbox 用于跨单元格检测
                 "text": item.get("text", ""),
                 "confidence": float(item.get("confidence", item.get("score", 1.0))),
                 "original_box": item,  # 保存完整的 OCR box 对象
             })
         
-        # 为每个单元格匹配OCR文本
-        for idx, bbox in enumerate(bboxes):
-            x1, y1, x2, y2 = bbox
-            matched: List[Tuple[str, float, float, Dict[str, Any]]] = [] # (text, y1, score, original_box)
+        # 按 (y1, x1) 排序,便于后续二分查找和提前退出
+        # 排序只需要一次,对整体性能影响很小(O(n log n))
+        ocr_items.sort(key=lambda item: (item["bbox"][1], item["bbox"][0]))
+        
+        # 重叠比例阈值:OCR box 与单元格的重叠面积必须 >= OCR box 面积的 50%
+        # 这确保 OCR box 主要属于该单元格
+        overlap_ratio_threshold = 0.5
+        
+        # 为每个单元格匹配OCR文本(使用中心点 + 重叠比例)
+        # 优化:使用二分查找和提前退出机制,减少遍历次数
+        # 创建一个 y1 值的列表用于二分查找(兼容 Python < 3.10)
+        ocr_y1_list = [item["bbox"][1] for item in ocr_items]
+        
+        for idx, cell_bbox in enumerate(bboxes):
+            cell_x1, cell_y1, cell_x2, cell_y2 = cell_bbox
+            matched: List[Tuple[str, float, float, float, float, Dict[str, Any]]] = [] # (text, y1, x1, overlap_ratio, score, original_box)
             
-            for ocr in ocr_items:
-                if x1 <= ocr["center_x"] <= x2 and y1 <= ocr["center_y"] <= y2:
-                    matched.append((ocr["text"], ocr["y1"], ocr["confidence"], ocr["original_box"]))
+            # 使用二分查找找到第一个 y1 >= cell_y1 的 OCR item
+            # 由于 ocr_items 已按 (y1, x1) 排序,可以使用 bisect_left
+            start_idx = bisect.bisect_left(ocr_y1_list, cell_y1)
+            
+            # 关键优化:OCR box 的 y1 可能 < cell_y1,但 y2 >= cell_y1(跨越单元格上边界)
+            # 为了不遗漏这种情况,我们需要向前查找一些 items
+            # 向前查找的最大数量:假设 OCR box 最大高度不超过 100 像素(可根据实际情况调整)
+            max_lookback = 20  # 向前查找最多 20 个 items
+            actual_start_idx = max(0, start_idx - max_lookback)
+            
+            # 从 actual_start_idx 开始遍历,当 y1 > cell_y2 时提前退出
+            for i in range(actual_start_idx, len(ocr_items)):
+                ocr_item = ocr_items[i]
+                ocr_bbox = ocr_item["bbox"]
+                
+                # 提前退出:如果 y1 > cell_y2,后续的 items 都不可能在单元格内
+                if ocr_bbox[1] > cell_y2:
+                    break
+                
+                # 快速过滤:如果 OCR box 的 y2 < cell_y1,说明它完全在单元格上方,跳过
+                if ocr_bbox[3] < cell_y1:
+                    continue
+                
+                cx = ocr_item["center_x"]
+                cy = ocr_item["center_y"]
+                
+                # 第一步:中心点必须在单元格内
+                if not (cell_x1 <= cx <= cell_x2 and cell_y1 <= cy <= cell_y2):
+                    continue
+                
+                # 第二步:检查重叠比例(OCR box 有多少部分在单元格内)
+                overlap_ratio = self.calculate_overlap_ratio(ocr_bbox, cell_bbox)
+                if overlap_ratio >= overlap_ratio_threshold:
+                    matched.append((
+                        ocr_item["text"], 
+                        ocr_bbox[1],  # y1 坐标
+                        ocr_bbox[0],  # 添加 x1 坐标
+                        overlap_ratio,
+                        ocr_item["confidence"], 
+                        ocr_item["original_box"]
+                    ))
             
             if matched:
-                # 按y坐标排序,确保多行文本顺序正确
-                matched.sort(key=lambda x: x[1])
-                texts[idx] = "".join([t for t, _, _, _ in matched])
+                # 直接按 y1 和 x1 排序,确保文本顺序正确
+                # y_tolerance 用于将相近的 y1 归为同一行(容差范围内视为同一行)
+                # 同一行内按 x1 从左到右排序
+                y_tolerance = 5
+                matched.sort(key=lambda x: (round(x[1] / y_tolerance), x[2]))  # 先按 y_group,再按 x1
+                
+                texts[idx] = "".join([t for t, _, _, _, _, _ in matched])
                 # 计算平均置信度
-                avg_score = sum([s for _, _, s, _ in matched]) / len(matched)
+                avg_score = sum([s for _, _, _, _, s, _ in matched]) / len(matched)
                 scores[idx] = avg_score
                 # 保存匹配到的 OCR boxes
-                matched_boxes_list[idx] = [box for _, _, _, box in matched]
-                
-                # 检测 OCR box 是否跨多个单元格或过大
-                for ocr_item in ocr_items:
-                    ocr_bbox = ocr_item["bbox"]
-                    # 检测是否跨多个单元格
-                    overlapping_cells = self.detect_ocr_box_spanning_cells(ocr_bbox, bboxes, overlap_threshold=0.3)
-                    if len(overlapping_cells) >= 2:
-                        # OCR box 跨多个单元格,标记所有相关单元格需要二次 OCR
-                        for cell_idx in overlapping_cells:
-                            if cell_idx not in need_reocr_indices:
-                                need_reocr_indices.append(cell_idx)
-                        logger.debug(f"检测到 OCR box 跨 {len(overlapping_cells)} 个单元格: {ocr_item['text'][:20]}...")
-                    
-                    # 检测 OCR box 是否相对于当前单元格过大
-                    if self.is_ocr_box_too_large(ocr_bbox, bbox, size_ratio_threshold=1.5):
-                        if idx not in need_reocr_indices:
-                            need_reocr_indices.append(idx)
-                        logger.debug(f"检测到 OCR box 相对于单元格过大 (单元格 {idx}): {ocr_item['text'][:20]}...")
+                matched_boxes_list[idx] = [box for _, _, _, _, _, box in matched]
             else:
                 scores[idx] = 0.0 # 无匹配文本,置信度为0
         
@@ -189,35 +266,6 @@ class TextFiller:
         
         return overlapping_cells
     
-    @staticmethod
-    def is_ocr_box_too_large(
-        ocr_bbox: List[float],
-        cell_bbox: List[float],
-        size_ratio_threshold: float = 1.5
-    ) -> bool:
-        """
-        检测 OCR box 是否相对于单元格过大
-        
-        Args:
-            ocr_bbox: OCR box 坐标 [x1, y1, x2, y2]
-            cell_bbox: 单元格坐标 [x1, y1, x2, y2]
-            size_ratio_threshold: 面积比阈值,如果 OCR box 面积 > 单元格面积 * 阈值,则认为过大
-            
-        Returns:
-            是否过大
-        """
-        if not ocr_bbox or len(ocr_bbox) < 4 or not cell_bbox or len(cell_bbox) < 4:
-            return False
-        
-        ocr_area = (ocr_bbox[2] - ocr_bbox[0]) * (ocr_bbox[3] - ocr_bbox[1])
-        cell_area = (cell_bbox[2] - cell_bbox[0]) * (cell_bbox[3] - cell_bbox[1])
-        
-        if cell_area <= 0:
-            return False
-        
-        size_ratio = ocr_area / cell_area
-        return size_ratio > size_ratio_threshold
-    
     def second_pass_ocr_fill(
         self,
         table_image: np.ndarray,
@@ -377,6 +425,10 @@ class TextFiller:
                     
                     if x2 > x1 and y2 > y1:
                         cropped = cell_img[y1:y2, x1:x2]
+                        ch, cw = cropped.shape[:2]
+                        # 小图放大
+                        if ch < 64 or cw < 64:
+                            cropped = cv2.resize(cropped, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
                         if cropped.size > 0:
                             rec_img_list.append(cropped)
                             rec_indices.append((cell_idx, box_idx))