فهرست منبع

feat(cell_fusion): 增强单元格融合逻辑,支持UNet过度合并拆分,添加新配置参数

zhch158_admin 1 هفته پیش
والد
کامیت
ce673e8fc6
1 فایلهای تغییر یافته به همراه 141 افزوده شده و 36 حذف شده
  1. 141 36
      ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

+ 141 - 36
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -57,11 +57,22 @@ class CellFusionEngine:
         self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
         self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
         self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
+        self.unet_split_min_count = self.config.get('unet_split_min_count', 2)
+        self.rtdetr_split_cover_threshold = self.config.get('rtdetr_split_cover_threshold', 0.5)
+        self.unet_split_cover_threshold = self.config.get('unet_split_cover_threshold', 0.5)
+        self.unet_split_rtdetr_score_threshold = self.config.get(
+            'unet_split_rtdetr_score_threshold',
+            self.rtdetr_conf_threshold
+        )
         
         logger.info(f"🔧 CellFusionEngine initialized: "
-                   f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
-                   f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
-                   f"boundary_filter={self.enable_boundary_noise_filter}")
+               f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
+               f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
+               f"boundary_filter={self.enable_boundary_noise_filter}, "
+               f"unet_split_min={self.unet_split_min_count}, "
+               f"unet_split_cover={self.unet_split_cover_threshold}, "
+               f"unet_split_score={self.unet_split_rtdetr_score_threshold}, "
+               f"rtdetr_split_cover={self.rtdetr_split_cover_threshold}")
     
     def should_use_rtdetr(
         self,
@@ -99,6 +110,7 @@ class CellFusionEngine:
         table_image: np.ndarray,
         unet_cells: List[List[float]],
         ocr_boxes: List[Dict[str, Any]],
+        ocr_text_pixel_tolerance: float = 10.0,
         pdf_type: str = 'ocr',
         debug_dir: Optional[str] = None,
         debug_prefix: str = "fusion"
@@ -110,6 +122,7 @@ class CellFusionEngine:
             table_image: 表格图像(原图坐标系)
             unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
             ocr_boxes: OCR结果列表
+            ocr_text_pixel_tolerance: OCR文本容差(原图坐标系,默认10.0)
             pdf_type: PDF类型 ('txt' 或 'ocr')
             debug_dir: 调试输出目录(可选)
             debug_prefix: 调试文件前缀
@@ -126,7 +139,7 @@ class CellFusionEngine:
             max(unet_cells, key=lambda box: box[2])[2], \
             max(unet_cells, key=lambda box: box[3])[3]
         ] if unet_cells else [0,0,0,0]
-        
+
         # 决策:是否使用 RT-DETR
         use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
         
@@ -165,8 +178,8 @@ class CellFusionEngine:
                 table_image,
                 conf_threshold=self.rtdetr_conf_threshold
             )
-            # rtdetr_result从上到下,从左到右排序
-            rtdetr_results.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
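+            # Sort RT-DETR results top-to-bottom, then left-to-right: y is bucketed via round(y / 10) (roughly 10px rows), x is rounded to the nearest integer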
+            rtdetr_results = sorted(rtdetr_results, key=lambda x: (round(x['bbox'][1] / 10), round(x['bbox'][0])))
             rtdetr_cells = [res['bbox'] for res in rtdetr_results]
             rtdetr_scores = [res['score'] for res in rtdetr_results]
             fusion_stats['rtdetr_count'] = len(rtdetr_cells)
@@ -179,7 +192,7 @@ class CellFusionEngine:
         
         # Phase 2: 智能融合
         # 使用稳健边界估计(避免单个超大单元格撑开边界)
-        table_bbox = self._estimate_robust_table_bbox(rtdetr_cells)
+        table_bbox = self._estimate_robust_table_bbox(rtdetr_cells, ocr_text_pixel_tolerance)
         
         # 将所有单元格的边界限制在表格边界内
         # rtdetr_cells = self._clip_cells_to_bbox(rtdetr_cells, table_bbox)
@@ -190,6 +203,7 @@ class CellFusionEngine:
         fusion_stats['merged_count'] = merge_stats['merged']
         fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
         fusion_stats['added_count'] = merge_stats['added']
+        fusion_stats['split_count'] = merge_stats.get('split', 0)
         
         # Phase 3: NMS 去重
         fused_cells, suppressed = self._nms_filter(fused_cells, self.iou_nms_threshold)
@@ -199,7 +213,8 @@ class CellFusionEngine:
         # Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
         if self.enable_boundary_noise_filter:
             fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
-                fused_cells, cell_labels, ocr_boxes, table_bbox
+                fused_cells, cell_labels, ocr_boxes, table_bbox,
+                boundary_tolerance=ocr_text_pixel_tolerance
             )
             fusion_stats['noise_filtered_count'] = noise_filtered
         else:
@@ -220,7 +235,7 @@ class CellFusionEngine:
         logger.info(
             f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
             f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
-            f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
+            f"Split={merge_stats.get('split', 0)}, Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
             f"OCRCompensated={fusion_stats.get('ocr_compensated_count', 0)}, Final={len(fused_cells)}"
         )
         
@@ -243,13 +258,14 @@ class CellFusionEngine:
         """
         融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
         
-        融合规则:
-        1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
-           - 判断RT-DETR单元格包含多少个UNet单元格
-           - 使用中心点+包含率判断(而非IoU)
-        2. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
-        3. RT-DETR 独有 + 高置信度 (>0.7) → 补充
-        4. UNet 独有 → 保留
+          融合规则:
+          1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
+              - 判断RT-DETR单元格包含多少个UNet单元格
+              - 使用中心点+包含率判断(而非IoU)
+          2. 检测UNet过度合并(一个UNet包含多个RT-DETR)并拆分
+          3. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
+          4. RT-DETR 独有 + 高置信度 (>0.7) → 补充
+          5. UNet 独有 → 保留
         
         包含关系判断逻辑:
         - UNet单元格的中心点在RT-DETR内
@@ -267,14 +283,14 @@ class CellFusionEngine:
             (fused_cells, stats, cell_labels)
             - fused_cells: 融合后的单元格
             - stats: {'merged': int, 'added': int, 'merged_cells': int, 'split': int}
-            - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
+            - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'split_rtdetr', 'new']
         """
         
         fused_cells = []
         cell_labels = []  # 记录每个单元格的来源标签
         unet_matched = [False] * len(unet_cells)
         rtdetr_matched = [False] * len(rtdetr_cells)
-        stats = {'merged': 0, 'added': 0, 'merged_cells': 0}
+        stats = {'merged': 0, 'added': 0, 'merged_cells': 0, 'split': 0}
         
         # Step 1: 检测RT-DETR的合并单元格(一对多匹配)
         # 遍历RT-DETR单元格,查找被包含的多个UNet单元格
@@ -332,7 +348,7 @@ class CellFusionEngine:
                     coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
                     
                     # 如果覆盖率>50%,说明这是一个真实的合并单元格
-                    if coverage > 0.5:
+                    if coverage > self.rtdetr_split_cover_threshold:
                         # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
                         fused_cell = [
                             min(bounding_x1, rtdetr_cell[0]),
@@ -342,9 +358,9 @@ class CellFusionEngine:
                         ]
                         # Clamp the fused cell to table_bbox on all four sides
                         fused_cell[0] = max(fused_cell[0], table_bbox[0])
-                        # fused_cell[1] = max(fused_cell[1], table_bbox[1])
+                        fused_cell[1] = max(fused_cell[1], table_bbox[1])
                         fused_cell[2] = min(fused_cell[2], table_bbox[2])
-                        # fused_cell[3] = min(fused_cell[3], table_bbox[3])
+                        fused_cell[3] = min(fused_cell[3], table_bbox[3])
                         fused_cells.append(fused_cell)
                         cell_labels.append('merged_span')  # 标记为合并单元格
                         rtdetr_matched[rt_idx] = True
@@ -357,6 +373,80 @@ class CellFusionEngine:
                             f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
                         )
         
+        # Step 1.5: 检测UNet过度合并(一个UNet包含多个RT-DETR)并拆分
+        for u_idx, unet_cell in enumerate(unet_cells):
+            if unet_matched[u_idx]:
+                continue
+
+            unet_area = self._calc_bbox_area(unet_cell)
+            if unet_area <= 0:
+                continue
+
+            contained_rtdetr = []
+            contained_intersects = []
+
+            for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
+                if rtdetr_matched[rt_idx]:
+                    continue
+                if rtdetr_scores[rt_idx] < self.unet_split_rtdetr_score_threshold:
+                    continue
+
+                rt_cx = (rtdetr_cell[0] + rtdetr_cell[2]) / 2
+                rt_cy = (rtdetr_cell[1] + rtdetr_cell[3]) / 2
+                if not (unet_cell[0] <= rt_cx <= unet_cell[2] and
+                        unet_cell[1] <= rt_cy <= unet_cell[3]):
+                    continue
+
+                intersect_x1 = max(unet_cell[0], rtdetr_cell[0])
+                intersect_y1 = max(unet_cell[1], rtdetr_cell[1])
+                intersect_x2 = min(unet_cell[2], rtdetr_cell[2])
+                intersect_y2 = min(unet_cell[3], rtdetr_cell[3])
+                if intersect_x2 <= intersect_x1 or intersect_y2 <= intersect_y1:
+                    continue
+
+                intersect_area = (intersect_x2 - intersect_x1) * (intersect_y2 - intersect_y1)
+                rtdetr_area = self._calc_bbox_area(rtdetr_cell)
+                contain_ratio = intersect_area / rtdetr_area if rtdetr_area > 0 else 0
+                if contain_ratio > 0.5:
+                    contained_rtdetr.append(rt_idx)
+                    contained_intersects.append(intersect_area)
+
+            if len(contained_rtdetr) >= self.unet_split_min_count:
+                # 计算总包含率:使用所有被包含RT-DETR单元格的外接矩形面积 vs UNet面积
+                # 与RT-DETR合并逻辑保持一致,避免相邻框重复/间隙导致覆盖率失真
+                rt_indices = contained_rtdetr
+                bounding_x1 = min(rtdetr_cells[i][0] for i in rt_indices)
+                bounding_y1 = min(rtdetr_cells[i][1] for i in rt_indices)
+                bounding_x2 = max(rtdetr_cells[i][2] for i in rt_indices)
+                bounding_y2 = max(rtdetr_cells[i][3] for i in rt_indices)
+                total_rtdetr_area = (bounding_x2 - bounding_x1) * (bounding_y2 - bounding_y1)
+                coverage = min(total_rtdetr_area / unet_area, 1.0)
+                if coverage >= self.unet_split_cover_threshold:
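+                    # UNet over-merge case: build one box spanning the UNet cell and the contained RT-DETR boxes, then clamp it to table_bbox. NOTE(review): a single union box is emitted, not one box per RT-DETR cell — confirm this matches the intended "split".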
+                    split_cell = [
+                        min(bounding_x1, unet_cell[0]),
+                        min(bounding_y1, unet_cell[1]),
+                        max(bounding_x2, unet_cell[2]),
+                        max(bounding_y2, unet_cell[3])
+                    ]
+                    split_cell = [
+                        max(split_cell[0], table_bbox[0]),
+                        max(split_cell[1], table_bbox[1]),
+                        min(split_cell[2], table_bbox[2]),
+                        min(split_cell[3], table_bbox[3])
+                    ]
+                    fused_cells.append(split_cell)
+                    cell_labels.append('split_rtdetr')
+                    for rt_idx in contained_rtdetr:
+                        rtdetr_matched[rt_idx] = True
+
+                    unet_matched[u_idx] = True
+                    stats['split'] += len(contained_rtdetr)
+                    logger.debug(
+                        f"🧩 UNet过度合并拆分: UNet[{u_idx}] -> {len(contained_rtdetr)} RT-DETR "
+                        f"(coverage={coverage:.2f})"
+                    )
+
         # Step 2: 一对一匹配(处理剩余的单元格)
         for u_idx, unet_cell in enumerate(unet_cells):
             if unet_matched[u_idx]:
@@ -401,9 +491,9 @@ class CellFusionEngine:
             if not rtdetr_matched[idx] and score > 0.7:
                 # Clamp rtdetr_cell to table_bbox on all four sides (the box is modified in place)
                 rtdetr_cell[0] = max(rtdetr_cell[0], table_bbox[0])
-                # rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
+                rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
                 rtdetr_cell[2] = min(rtdetr_cell[2], table_bbox[2])
-                # rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
+                rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
                 fused_cells.append(rtdetr_cell)
                 cell_labels.append('rtdetr_only')  # 标记为RT-DETR独有
                 stats['added'] += 1
@@ -418,18 +508,16 @@ class CellFusionEngine:
         """
         稳健的表格边界估计
         
-        使用聚类方法找到"主流"的左右边界,避免单个超大单元格撑开边界。
+        使用聚类方法找到"主流"的边界,避免单个超大单元格撑开边界。
         
         算法:
-        1. 收集所有单元格的左边界x1和右边界x2
-        2. 对x1聚类,选择支持度最高的聚类中心作为表格左边界
-        3. 对x2聚类,选择支持度最高的聚类中心作为表格右边界
-        4. y方向使用简单的min/max(行高变化大,不适合聚类)
+        1. 收集所有单元格的边界
+        2. 聚类,选择支持度最高的聚类中心作为表格边界
+        3. 通过容差向内调整边界,过滤掉过于宽松的边界(可能包含噪声单元格)
         
         Args:
             rtdetr_cells: RT-DETR单元格列表
             cluster_tolerance: 聚类容差(像素)
-            
         Returns:
             table_bbox: [x1, y1, x2, y2]
         """
@@ -448,11 +536,13 @@ class CellFusionEngine:
         # 对x2聚类,找主流右边界
         robust_x2 = self._find_dominant_boundary(x2_coords, cluster_tolerance, mode='max')
         # y boundaries use the same dominant-boundary estimate as x
-        robust_y1 = min(y1_coords)
-        robust_y2 = max(y2_coords)
+        robust_y1 = self._find_dominant_boundary(y1_coords, cluster_tolerance, mode='min')
+        robust_y2 = self._find_dominant_boundary(y2_coords, cluster_tolerance, mode='max')
         
         logger.debug(f"📐 稳健边界估计: x=[{robust_x1:.1f}, {robust_x2:.1f}], "
-                    f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]")
+                    f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]"
+                    f" | y=[{robust_y1:.1f}, {robust_y2:.1f}], "
+                    f"原始y范围=[{min(y1_coords):.1f}, {max(y2_coords):.1f}]")
         
         return [robust_x1, robust_y1, robust_x2, robust_y2]
     
@@ -624,7 +714,8 @@ class CellFusionEngine:
         cells: List[List[float]],
         cell_labels: List[str],
         ocr_boxes: List[Dict[str, Any]],
-        rtdetr_bbox: List[float]
+        rtdetr_bbox: List[float],
+        boundary_tolerance: float = 0.0
     ) -> Tuple[List[List[float]], List[str], int]:
         """
         过滤边界噪声单元格
@@ -639,6 +730,7 @@ class CellFusionEngine:
             cell_labels: 单元格标签列表
             ocr_boxes: OCR结果列表
             rtdetr_bbox: RT-DETR单元格的边界框 [x1, y1, x2, y2]
+            boundary_tolerance: 边界判定容忍范围(像素,原图坐标系)
         Returns:
             (filtered_cells, filtered_labels, filtered_count)
         """
@@ -646,6 +738,8 @@ class CellFusionEngine:
         filtered_labels = []
         filtered_count = 0
         
+        tol = max(0.0, boundary_tolerance)
+
         for cell, label in zip(cells, cell_labels):
             # # 只过滤 unet_only 标记的单元格
             # if label != 'unet_only':
@@ -655,9 +749,9 @@ class CellFusionEngine:
             
             x1, y1, x2, y2 = cell
             
-            # 检查是否在边界
-            is_left_boundary = x1 <= rtdetr_bbox[0]
-            is_right_boundary = x2 >= rtdetr_bbox[2]
+            # 检查是否在边界(加入容忍范围,避免贴边被误判)
+            is_left_boundary = x1 <= (rtdetr_bbox[0] - tol)
+            is_right_boundary = x2 >= (rtdetr_bbox[2] + tol)
             is_on_boundary = is_left_boundary or is_right_boundary
             
             if not is_on_boundary:
@@ -906,6 +1000,7 @@ class CellFusionEngine:
             merged_cells_1to1 = []  # 1:1融合单元格(黄色)
             merged_cells_span = []  # 合并单元格(品红色,RT-DETR检测的跨格单元格)
             new_cells = []  # 新增单元格(紫色)
+            split_cells = []  # UNet拆分得到的RT-DETR单元格(青色)
             ocr_compensated = []  # OCR补偿单元格(橙色)
             
             for fused_cell, label in zip(fused_cells, cell_labels):
@@ -919,6 +1014,8 @@ class CellFusionEngine:
                     merged_cells_span.append(fused_cell)
                 elif label == 'new':
                     new_cells.append(fused_cell)
+                elif label == 'split_rtdetr':
+                    split_cells.append(fused_cell)
                 elif label == 'ocr_compensated':
                     ocr_compensated.append(fused_cell)
             
@@ -942,6 +1039,10 @@ class CellFusionEngine:
             for cell in new_cells:
                 x1, y1, x2, y2 = [int(v) for v in cell]
                 cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2)  # 紫色 - 新增
+
+            for cell in split_cells:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 255, 0), 3)  # 青色 - UNet拆分
             
             for cell in ocr_compensated:
                 x1, y1, x2, y2 = [int(v) for v in cell]
@@ -967,6 +1068,10 @@ class CellFusionEngine:
                 legend_y += 30
                 cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
+            if split_cells:
+                legend_y += 30
+                cv2.putText(img3, f"Cyan: Split ({len(split_cells)})", (10, legend_y),
+                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
             if ocr_compensated:
                 legend_y += 30
                 cv2.putText(img3, f"Orange: OCR Compensated ({len(ocr_compensated)})", (10, legend_y),