Explorar el Código

feat(cell_fusion): 增强单元格融合逻辑,支持合并单元格检测并添加来源标签

zhch158_admin hace 2 semanas
padre
commit
6e0bcc305d

+ 214 - 48
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -140,12 +140,13 @@ class CellFusionEngine:
         # 如果不使用RT-DETR,直接返回UNet结果
         if not use_rtdetr:
             fused_cells = unet_cells.copy()
+            cell_labels = ['unet_only'] * len(fused_cells)  # 所有都是UNet独有
             fusion_stats['fused_count'] = len(fused_cells)
             
             # 可选:OCR补偿
             if self.enable_ocr_compensation and ocr_boxes:
-                fused_cells, ocr_comp_count = self._compensate_with_ocr(
-                    fused_cells, ocr_boxes, (w, h)
+                fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
+                    fused_cells, cell_labels, ocr_boxes, (w, h)
                 )
                 fusion_stats['ocr_compensated_count'] = ocr_comp_count
                 fusion_stats['fused_count'] = len(fused_cells)
@@ -159,6 +160,8 @@ class CellFusionEngine:
                 table_image,
                 conf_threshold=self.rtdetr_conf_threshold
             )
+            # rtdetr_result从上到下,从左到右排序
+            rtdetr_results.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
             rtdetr_cells = [res['bbox'] for res in rtdetr_results]
             rtdetr_scores = [res['score'] for res in rtdetr_results]
             fusion_stats['rtdetr_count'] = len(rtdetr_cells)
@@ -171,10 +174,11 @@ class CellFusionEngine:
             return fused_cells, fusion_stats
         
         # Phase 2: 智能融合
-        fused_cells, merge_stats = self._fuse_cells(
+        fused_cells, merge_stats, cell_labels = self._fuse_cells(
             unet_cells, rtdetr_cells, rtdetr_scores
         )
         fusion_stats['merged_count'] = merge_stats['merged']
+        fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
         fusion_stats['added_count'] = merge_stats['added']
         
         # Phase 3: NMS 去重
@@ -182,8 +186,8 @@ class CellFusionEngine:
         
         # Phase 4: OCR 补偿(可选)
         if self.enable_ocr_compensation and ocr_boxes:
-            fused_cells, ocr_comp_count = self._compensate_with_ocr(
-                fused_cells, ocr_boxes, (w, h)
+            fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
+                fused_cells, cell_labels, ocr_boxes, (w, h)
             )
             fusion_stats['ocr_compensated_count'] = ocr_comp_count
         
@@ -191,14 +195,14 @@ class CellFusionEngine:
         
         logger.info(
             f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
-            f"Merged={merge_stats['merged']}, Added={merge_stats['added']}, "
-            f"Final={len(fused_cells)}"
+            f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
+            f"Added={merge_stats['added']}, Final={len(fused_cells)}"
         )
         
         # 可视化(调试)
         if debug_dir:
             self._visualize_fusion(
-                table_image, unet_cells, rtdetr_cells, fused_cells,
+                table_image, unet_cells, rtdetr_cells, fused_cells, cell_labels,
                 debug_dir, debug_prefix
             )
         
@@ -209,14 +213,23 @@ class CellFusionEngine:
         unet_cells: List[List[float]],
         rtdetr_cells: List[List[float]],
         rtdetr_scores: List[float]
-    ) -> Tuple[List[List[float]], Dict[str, int]]:
+    ) -> Tuple[List[List[float]], Dict[str, int], List[str]]:
         """
-        融合 UNet 和 RT-DETR 检测结果
+        融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
         
         融合规则:
-        1. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并
-        2. RT-DETR 独有 + 高置信度 (>0.7) → 补充
-        3. UNet 独有 → 保留
+        1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
+           - 判断RT-DETR单元格包含多少个UNet单元格
+           - 使用中心点+包含率判断(而非IoU)
+        2. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
+        3. RT-DETR 独有 + 高置信度 (>0.7) → 补充
+        4. UNet 独有 → 保留
+        
+        包含关系判断逻辑:
+        - UNet单元格的中心点在RT-DETR内
+        - UNet单元格的50%以上面积在RT-DETR内
+        - RT-DETR包含≥2个UNet单元格
+        - 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
         
         Args:
             unet_cells: UNet单元格列表
@@ -224,28 +237,104 @@ class CellFusionEngine:
             rtdetr_scores: RT-DETR置信度列表
             
         Returns:
-            (fused_cells, stats)
+            (fused_cells, stats, cell_labels)
             - fused_cells: 融合后的单元格
-            - stats: {'merged': int, 'added': int}
+            - stats: {'merged': int, 'added': int, 'merged_cells': int}
+            - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
         """
         fused_cells = []
+        cell_labels = []  # 记录每个单元格的来源标签
+        unet_matched = [False] * len(unet_cells)
         rtdetr_matched = [False] * len(rtdetr_cells)
-        stats = {'merged': 0, 'added': 0}
+        stats = {'merged': 0, 'added': 0, 'merged_cells': 0}
         
-        # Step 1: 遍历 UNet 单元格,尝试与 RT-DETR 匹配
-        for unet_cell in unet_cells:
+        # Step 1: 检测RT-DETR的合并单元格(一对多匹配)
+        # 遍历RT-DETR单元格,查找被包含的多个UNet单元格
+        for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
+            if rtdetr_matched[rt_idx]:
+                continue
+            
+            # 查找所有被当前RT-DETR单元格包含(或大部分包含)的UNet单元格
+            contained_unet = []
+            for u_idx, unet_cell in enumerate(unet_cells):
+                if unet_matched[u_idx]:
+                    continue
+                
+                # 判断UNet单元格是否被RT-DETR单元格包含
+                # 方法1: 检查UNet的中心点是否在RT-DETR内
+                unet_cx = (unet_cell[0] + unet_cell[2]) / 2
+                unet_cy = (unet_cell[1] + unet_cell[3]) / 2
+                
+                if (rtdetr_cell[0] <= unet_cx <= rtdetr_cell[2] and 
+                    rtdetr_cell[1] <= unet_cy <= rtdetr_cell[3]):
+                    # UNet中心点在RT-DETR内,计算包含程度
+                    # 计算UNet有多少面积在RT-DETR内
+                    intersect_x1 = max(unet_cell[0], rtdetr_cell[0])
+                    intersect_y1 = max(unet_cell[1], rtdetr_cell[1])
+                    intersect_x2 = min(unet_cell[2], rtdetr_cell[2])
+                    intersect_y2 = min(unet_cell[3], rtdetr_cell[3])
+                    
+                    if intersect_x2 > intersect_x1 and intersect_y2 > intersect_y1:
+                        intersect_area = (intersect_x2 - intersect_x1) * (intersect_y2 - intersect_y1)
+                        unet_area = (unet_cell[2] - unet_cell[0]) * (unet_cell[3] - unet_cell[1])
+                        contain_ratio = intersect_area / unet_area if unet_area > 0 else 0
+                        
+                        # 如果UNet单元格的50%以上在RT-DETR内,认为被包含
+                        if contain_ratio > 0.5:
+                            contained_unet.append((u_idx, contain_ratio))
+            
+            # 判断是否为合并单元格(RT-DETR包含多个UNet单元格)
+            if len(contained_unet) >= 2:
+                # 合并单元格场景:优先使用RT-DETR的大框
+                # 条件:1) 包含2个以上UNet单元格 2) RT-DETR置信度足够高
+                if rtdetr_scores[rt_idx] > 0.7:
+                    # 计算总包含率:使用所有被包含UNet单元格的外接矩形面积 vs RT-DETR面积
+                    # 使用外接矩形更合理,因为:
+                    # 1. 合并单元格是一个完整区域,应包括单元格间隙
+                    # 2. 避免重复计算相邻单元格的边界
+                    # 3. 更准确反映覆盖率(如11个连续单元格应该接近100%覆盖)
+                    unet_indices = [u_idx for u_idx, _ in contained_unet]
+                    bounding_x1 = min(unet_cells[i][0] for i in unet_indices)
+                    bounding_y1 = min(unet_cells[i][1] for i in unet_indices)
+                    bounding_x2 = max(unet_cells[i][2] for i in unet_indices)
+                    bounding_y2 = max(unet_cells[i][3] for i in unet_indices)
+                    total_unet_area = (bounding_x2 - bounding_x1) * (bounding_y2 - bounding_y1)
+                    
+                    rtdetr_area = self._calc_bbox_area(rtdetr_cell)
+                    coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
+                    
+                    # 如果覆盖率>40%,说明这是一个真实的合并单元格
+                    # 降低阈值从0.5到0.4,因为合并单元格可能包含很多空白区域
+                    if coverage > 0.4:
+                        fused_cells.append(rtdetr_cell)
+                        cell_labels.append('merged_span')  # 标记为合并单元格
+                        rtdetr_matched[rt_idx] = True
+                        # 标记所有被包含的UNet单元格
+                        for u_idx, contain_ratio in contained_unet:
+                            unet_matched[u_idx] = True
+                        stats['merged_cells'] += 1
+                        logger.debug(
+                            f"🔗 检测到合并单元格: RT-DETR[{rt_idx}] 包含 {len(contained_unet)} 个UNet单元格 "
+                            f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
+                        )
+        
+        # Step 2: 一对一匹配(处理剩余的单元格)
+        for u_idx, unet_cell in enumerate(unet_cells):
+            if unet_matched[u_idx]:
+                continue
+            
             best_match_idx = -1
             best_iou = 0.0
             
             # 查找最佳匹配的 RT-DETR 单元格
-            for idx, rtdetr_cell in enumerate(rtdetr_cells):
-                if rtdetr_matched[idx]:
+            for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
+                if rtdetr_matched[rt_idx]:
                     continue
                 
                 iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
                 if iou > best_iou:
                     best_iou = iou
-                    best_match_idx = idx
+                    best_match_idx = rt_idx
             
             # 判断是否合并
             if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
@@ -257,19 +346,28 @@ class CellFusionEngine:
                     self.rtdetr_weight
                 )
                 fused_cells.append(merged_cell)
+                cell_labels.append('merged_1to1')  # 标记为1:1融合
                 rtdetr_matched[best_match_idx] = True
+                unet_matched[u_idx] = True
                 stats['merged'] += 1
             else:
                 # UNet 独有:保留
                 fused_cells.append(unet_cell)
+                cell_labels.append('unet_only')  # 标记为UNet独有
+                unet_matched[u_idx] = True
         
-        # Step 2: 补充 RT-DETR 独有的高置信度单元格
+        # Step 3: 补充 RT-DETR 独有的高置信度单元格
         for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
             if not rtdetr_matched[idx] and score > 0.7:
                 fused_cells.append(rtdetr_cell)
+                cell_labels.append('rtdetr_only')  # 标记为RT-DETR独有
                 stats['added'] += 1
         
-        return fused_cells, stats
+        return fused_cells, stats, cell_labels
+    
+    def _calc_bbox_area(self, bbox: List[float]) -> float:
+        """计算bbox面积"""
+        return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
     
     def _weighted_merge_bbox(
         self,
@@ -345,9 +443,10 @@ class CellFusionEngine:
     def _compensate_with_ocr(
         self,
         cells: List[List[float]],
+        cell_labels: List[str],
         ocr_boxes: List[Dict[str, Any]],
         table_size: Tuple[int, int]
-    ) -> Tuple[List[List[float]], int]:
+    ) -> Tuple[List[List[float]], List[str], int]:
         """
         使用 OCR 补偿遗漏的单元格
         
@@ -355,13 +454,15 @@ class CellFusionEngine:
         
         Args:
             cells: 现有单元格列表
+            cell_labels: 单元格标签列表
             ocr_boxes: OCR结果列表
             table_size: 表格尺寸 (width, height)
             
         Returns:
-            (compensated_cells, compensation_count)
+            (compensated_cells, compensated_labels, compensation_count)
         """
         compensated = cells.copy()
+        compensated_labels = cell_labels.copy()
         compensation_count = 0
         w, h = table_size
         
@@ -391,26 +492,27 @@ class CellFusionEngine:
                 # 扩展 OCR bbox 作为新单元格
                 if len(ocr_bbox) == 8:
                     new_cell = [
-                        max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5),
-                        max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5),
-                        min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5),
-                        min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5)
+                        float(max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5)),
+                        float(max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5)),
+                        float(min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5)),
+                        float(min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5))
                     ]
                 else:
                     new_cell = [
-                        max(0, ocr_bbox[0] - 5),
-                        max(0, ocr_bbox[1] - 5),
-                        min(w, ocr_bbox[2] + 5),
-                        min(h, ocr_bbox[3] + 5)
+                        float(max(0, ocr_bbox[0] - 5)),
+                        float(max(0, ocr_bbox[1] - 5)),
+                        float(min(w, ocr_bbox[2] + 5)),
+                        float(min(h, ocr_bbox[3] + 5))
                     ]
                 
                 compensated.append(new_cell)
+                compensated_labels.append('new')  # 标记为新增(OCR补偿)
                 compensation_count += 1
         
         if compensation_count > 0:
             logger.debug(f"OCR compensation: added {compensation_count} cells")
         
-        return compensated, compensation_count
+        return compensated, compensated_labels, compensation_count
     
     def _visualize_fusion(
         self,
@@ -418,10 +520,11 @@ class CellFusionEngine:
         unet_cells: List[List[float]],
         rtdetr_cells: List[List[float]],
         fused_cells: List[List[float]],
+        cell_labels: List[str],
         debug_dir: str,
         debug_prefix: str
     ):
-        """可视化融合结果(调试用)"""
+        """可视化融合结果(调试用)- 增强版:用颜色区分不同来源的单元格"""
         try:
             import cv2
             from pathlib import Path
@@ -429,35 +532,96 @@ class CellFusionEngine:
             output_dir = Path(debug_dir)
             output_dir.mkdir(parents=True, exist_ok=True)
             
-            # 创建三栏对比图
             h, w = table_image.shape[:2]
-            vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
             
-            # 左栏:UNet
+            # === 图1:UNet原始结果 ===
             img1 = table_image.copy()
+            if len(img1.shape) == 2:
+                img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
             for cell in unet_cells:
                 x1, y1, x2, y2 = [int(v) for v in cell]
-                cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2)  # 绿色
             cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
             
-            # 中栏:RT-DETR
+            # === 图2:RT-DETR原始结果 ===
             img2 = table_image.copy()
+            if len(img2.shape) == 2:
+                img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
             for cell in rtdetr_cells:
                 x1, y1, x2, y2 = [int(v) for v in cell]
-                cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2)
+                cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2)  # 蓝色
             cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
             
-            # 右栏:融合结果
+            # === 图3:融合结果(按来源分类)===
             img3 = table_image.copy()
-            for cell in fused_cells:
+            if len(img3.shape) == 2:
+                img3 = cv2.cvtColor(img3, cv2.COLOR_GRAY2BGR)
+            
+            # 根据标签分类单元格(使用 _fuse_cells 中记录的标签)
+            unet_only = []  # UNet独有(绿色)
+            rtdetr_only = []  # RT-DETR独有(蓝色)
+            merged_cells_1to1 = []  # 1:1融合单元格(黄色)
+            merged_cells_span = []  # 合并单元格(品红色,RT-DETR检测的跨格单元格)
+            new_cells = []  # 新增单元格(紫色)
+            
+            for fused_cell, label in zip(fused_cells, cell_labels):
+                if label == 'unet_only':
+                    unet_only.append(fused_cell)
+                elif label == 'rtdetr_only':
+                    rtdetr_only.append(fused_cell)
+                elif label == 'merged_1to1':
+                    merged_cells_1to1.append(fused_cell)
+                elif label == 'merged_span':
+                    merged_cells_span.append(fused_cell)
+                elif label == 'new':
+                    new_cells.append(fused_cell)
+            
+            # 绘制不同类型的单元格
+            for cell in unet_only:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 0), 2)  # 绿色 - UNet独有
+            
+            for cell in rtdetr_only:
                 x1, y1, x2, y2 = [int(v) for v in cell]
-                cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 2)
-            cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, 30),
-                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 0), 2)  # 蓝色 - RT-DETR独有
             
-            # 拼接
+            for cell in merged_cells_1to1:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 3)  # 黄色 - 1:1融合(加粗)
+            
+            for cell in merged_cells_span:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 255), 4)  # 品红色 - 合并单元格(加粗)
+            
+            for cell in new_cells:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2)  # 紫色 - 新增
+            
+            # 添加图例
+            legend_y = 30
+            cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, legend_y),
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
+            legend_y += 35
+            cv2.putText(img3, f"Green: UNet-only ({len(unet_only)})", (10, legend_y),
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+            legend_y += 30
+            cv2.putText(img3, f"Blue: RTDETR-only ({len(rtdetr_only)})", (10, legend_y),
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
+            legend_y += 30
+            cv2.putText(img3, f"Yellow: 1:1 Merged ({len(merged_cells_1to1)})", (10, legend_y),
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
+            legend_y += 30
+            cv2.putText(img3, f"Magenta: Span Cells ({len(merged_cells_span)})", (10, legend_y),
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 255), 2)
+            if new_cells:
+                legend_y += 30
+                cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
+                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
+            
+            # 拼接三栏对比
+            vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
             vis_canvas[:, :w] = img1
             vis_canvas[:, w:2*w] = img2
             vis_canvas[:, 2*w:] = img3
@@ -465,7 +629,9 @@ class CellFusionEngine:
             # 保存
             output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
             cv2.imwrite(str(output_path), vis_canvas)
-            logger.debug(f"💾 Fusion visualization saved: {output_path}")
+            logger.info(f"💾 融合可视化已保存: {output_path}")
+            logger.info(f"   📊 单元格分类: UNet独有={len(unet_only)}, RT-DETR独有={len(rtdetr_only)}, "
+                       f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, 新增={len(new_cells)}")
             
         except Exception as e:
             logger.warning(f"Failed to visualize fusion: {e}")