Bladeren bron

feat(cell_fusion): 添加边界噪声过滤功能以提高单元格融合质量

zhch158_admin 2 weken geleden
bovenliggende
commit
49c5f418ac
1 gewijzigde bestanden met toevoegingen van 108 en 89 verwijderingen
  1. 108 89
      ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

+ 108 - 89
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -42,6 +42,7 @@ class CellFusionEngine:
                 - rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
                 - enable_ocr_compensation: True (启用OCR补偿)
                 - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
+                - enable_boundary_noise_filter: True (启用边界噪声过滤)
         """
         self.rtdetr_detector = rtdetr_detector
         self.config = config or {}
@@ -54,10 +55,12 @@ class CellFusionEngine:
         self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
         self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
         self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
+        self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
         
         logger.info(f"🔧 CellFusionEngine initialized: "
                    f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
-                   f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}")
+                   f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}, "
+                   f"boundary_filter={self.enable_boundary_noise_filter}")
     
     def should_use_rtdetr(
         self,
@@ -101,7 +104,6 @@ class CellFusionEngine:
         unet_cells: List[List[float]],
         ocr_boxes: List[Dict[str, Any]],
         pdf_type: str = 'ocr',
-        upscale: float = 1.0,
         debug_dir: Optional[str] = None,
         debug_prefix: str = "fusion"
     ) -> Tuple[List[List[float]], Dict[str, Any]]:
@@ -113,7 +115,6 @@ class CellFusionEngine:
             unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
             ocr_boxes: OCR结果列表
             pdf_type: PDF类型 ('txt' 或 'ocr')
-            upscale: UNet的上采样比例
             debug_dir: 调试输出目录(可选)
             debug_prefix: 调试文件前缀
             
@@ -123,6 +124,12 @@ class CellFusionEngine:
             - fusion_stats: 融合统计信息
         """
         h, w = table_image.shape[:2]
+        unet_bbox = [
+            min(unet_cells, key=lambda box: box[0])[0], \
+            min(unet_cells, key=lambda box: box[1])[1], \
+            max(unet_cells, key=lambda box: box[2])[2], \
+            max(unet_cells, key=lambda box: box[3])[3]
+        ] if unet_cells else [0,0,0,0]
         
         # 决策:是否使用 RT-DETR
         use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
@@ -143,14 +150,6 @@ class CellFusionEngine:
             cell_labels = ['unet_only'] * len(fused_cells)  # 所有都是UNet独有
             fusion_stats['fused_count'] = len(fused_cells)
             
-            # 可选:OCR补偿
-            if self.enable_ocr_compensation and ocr_boxes:
-                fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
-                    fused_cells, cell_labels, ocr_boxes, (w, h)
-                )
-                fusion_stats['ocr_compensated_count'] = ocr_comp_count
-                fusion_stats['fused_count'] = len(fused_cells)
-            
             logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
             return fused_cells, fusion_stats
         
@@ -165,6 +164,12 @@ class CellFusionEngine:
             rtdetr_cells = [res['bbox'] for res in rtdetr_results]
             rtdetr_scores = [res['score'] for res in rtdetr_results]
             fusion_stats['rtdetr_count'] = len(rtdetr_cells)
+            rtdetr_bbox = [
+                min(rtdetr_cells, key=lambda box: box[0])[0],
+                min(rtdetr_cells, key=lambda box: box[1])[1],
+                max(rtdetr_cells, key=lambda box: box[2])[2],
+                max(rtdetr_cells, key=lambda box: box[3])[3]
+            ] if rtdetr_cells else [0,0,0,0]
             
             logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
         except Exception as e:
@@ -175,28 +180,33 @@ class CellFusionEngine:
         
         # Phase 2: 智能融合
         fused_cells, merge_stats, cell_labels = self._fuse_cells(
-            unet_cells, rtdetr_cells, rtdetr_scores
+            unet_bbox, unet_cells, rtdetr_cells, rtdetr_scores
         )
         fusion_stats['merged_count'] = merge_stats['merged']
         fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
         fusion_stats['added_count'] = merge_stats['added']
         
         # Phase 3: NMS 去重
-        fused_cells = self._nms_filter(fused_cells, self.iou_nms_threshold)
-        
-        # Phase 4: OCR 补偿(可选)
-        # if self.enable_ocr_compensation and ocr_boxes:
-        #     fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
-        #         fused_cells, cell_labels, ocr_boxes, (w, h)
-        #     )
-        #     fusion_stats['ocr_compensated_count'] = ocr_comp_count
+        fused_cells, suppressed = self._nms_filter(fused_cells, self.iou_nms_threshold)
+        # 同步更新 cell_labels
+        cell_labels = [label for label, keep in zip(cell_labels, suppressed) if not keep]
+        
+        # Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
+        if self.enable_boundary_noise_filter:
+            fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
+                fused_cells, cell_labels, ocr_boxes, rtdetr_bbox
+            )
+            fusion_stats['noise_filtered_count'] = noise_filtered
+        else:
+            fusion_stats['noise_filtered_count'] = 0
+            noise_filtered = 0
         
         fusion_stats['fused_count'] = len(fused_cells)
         
         logger.info(
             f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
             f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
-            f"Added={merge_stats['added']}, Final={len(fused_cells)}"
+            f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, Final={len(fused_cells)}"
         )
         
         # 可视化(调试)
@@ -210,6 +220,7 @@ class CellFusionEngine:
     
     def _fuse_cells(
         self,
+        unet_bbox: List[float],
         unet_cells: List[List[float]],
         rtdetr_cells: List[List[float]],
         rtdetr_scores: List[float]
@@ -232,6 +243,7 @@ class CellFusionEngine:
         - 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
         
         Args:
+            unet_bbox: UNet单元格的边界框 [x1, y1, x2, y2]
             unet_cells: UNet单元格列表
             rtdetr_cells: RT-DETR单元格列表
             rtdetr_scores: RT-DETR置信度列表
@@ -243,14 +255,6 @@ class CellFusionEngine:
             - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
         """
         
-        # 计算unet_cells的边界框bbox[x1,y1,x2,y2]
-        unet_bbox = [
-                    min(unet_cells, key=lambda box: box[0])[0], \
-                    min(unet_cells, key=lambda box: box[1])[1], \
-                    max(unet_cells, key=lambda box: box[2])[2], \
-                    max(unet_cells, key=lambda box: box[3])[3]
-                ] if unet_cells else [0,0,0,0]
-                    
         fused_cells = []
         cell_labels = []  # 记录每个单元格的来源标签
         unet_matched = [False] * len(unet_cells)
@@ -360,6 +364,7 @@ class CellFusionEngine:
             if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
                 # 高IoU:加权平均合并
                 merged_cell = self._weighted_merge_bbox(
+                    unet_bbox,
                     unet_cell,
                     rtdetr_cells[best_match_idx],
                     self.unet_weight,
@@ -396,6 +401,7 @@ class CellFusionEngine:
     
     def _weighted_merge_bbox(
         self,
+        table_bbox: List[float],
         bbox1: List[float],
         bbox2: List[float],
         weight1: float,
@@ -405,6 +411,7 @@ class CellFusionEngine:
         加权平均合并两个 bbox
         
         Args:
+            table_bbox: 表格整体 bbox(用于限制合并结果)
             bbox1: [x1, y1, x2, y2]
             bbox2: [x1, y1, x2, y2]
             weight1: bbox1 的权重
@@ -414,17 +421,17 @@ class CellFusionEngine:
             merged_bbox: [x1, y1, x2, y2]
         """
         return [
-            weight1 * bbox1[0] + weight2 * bbox2[0],
-            weight1 * bbox1[1] + weight2 * bbox2[1],
-            weight1 * bbox1[2] + weight2 * bbox2[2],
-            weight1 * bbox1[3] + weight2 * bbox2[3]
+            max(table_bbox[0], weight1 * bbox1[0] + weight2 * bbox2[0]),
+            max(table_bbox[1], weight1 * bbox1[1] + weight2 * bbox2[1]),
+            min(table_bbox[2], weight1 * bbox1[2] + weight2 * bbox2[2]),
+            min(table_bbox[3], weight1 * bbox1[3] + weight2 * bbox2[3])
         ]
     
     def _nms_filter(
         self,
         cells: List[List[float]],
         iou_threshold: float
-    ) -> List[List[float]]:
+    ) -> Tuple[List[List[float]], List[bool]]:
         """
         简单 NMS 过滤(去除高度重叠的冗余框)
         
@@ -436,9 +443,10 @@ class CellFusionEngine:
             
         Returns:
             过滤后的单元格列表
+            抑制标记列表
         """
         if len(cells) == 0:
-            return []
+            return [], []
         
         # 计算面积并排序(大框优先)
         areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in cells]
@@ -463,81 +471,92 @@ class CellFusionEngine:
                     suppressed[other_idx] = True
         
         logger.debug(f"NMS: {len(cells)} → {len(keep)} cells (threshold={iou_threshold})")
-        return keep
+        return keep, suppressed
     
-    def _compensate_with_ocr(
+    def _filter_boundary_noise(
         self,
         cells: List[List[float]],
         cell_labels: List[str],
         ocr_boxes: List[Dict[str, Any]],
-        table_size: Tuple[int, int]
+        rtdetr_bbox: List[float]
     ) -> Tuple[List[List[float]], List[str], int]:
         """
-        使用 OCR 补偿遗漏的单元格
+        过滤边界噪声单元格
         
-        策略:如果 OCR 文本没有匹配到任何单元格,创建新单元格
+        过滤条件:
+        1. 单元格标记为 'unet_only'(只在 UNet 中检测到,RT-DETR 未匹配)
+        2. 单元格位于表格边界(左边界或右边界)
+        3. 单元格内没有任何 OCR 文本框(说明是空白区域)
         
         Args:
-            cells: 现有单元格列表
+            cells: 单元格列表
             cell_labels: 单元格标签列表
             ocr_boxes: OCR结果列表
-            table_size: 表格尺寸 (width, height)
-            
+            rtdetr_bbox: RT-DETR单元格的边界框 [x1, y1, x2, y2]
         Returns:
-            (compensated_cells, compensated_labels, compensation_count)
+            (filtered_cells, filtered_labels, filtered_count)
         """
-        compensated = cells.copy()
-        compensated_labels = cell_labels.copy()
-        compensation_count = 0
-        w, h = table_size
-        
-        for ocr in ocr_boxes:
-            ocr_bbox = ocr.get('bbox', [])
-            if not ocr_bbox or len(ocr_bbox) < 4:
+        filtered_cells = []
+        filtered_labels = []
+        filtered_count = 0
+        
+        for cell, label in zip(cells, cell_labels):
+            # # 只过滤 unet_only 标记的单元格
+            # if label != 'unet_only':
+            #     filtered_cells.append(cell)
+            #     filtered_labels.append(label)
+            #     continue
+            
+            x1, y1, x2, y2 = cell
+            
+            # 检查是否在边界
+            is_left_boundary = x1 <= rtdetr_bbox[0]
+            is_right_boundary = x2 >= rtdetr_bbox[2]
+            is_on_boundary = is_left_boundary or is_right_boundary
+            
+            if not is_on_boundary:
+                # 不在边界,保留
+                filtered_cells.append(cell)
+                filtered_labels.append(label)
                 continue
             
-            # 计算 OCR 中心点
-            if len(ocr_bbox) == 8:  # poly format
-                ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
-                ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
-            else:  # bbox format
-                ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
-                ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
-            
-            # 检查是否在任何单元格内
-            is_covered = False
-            for cell in compensated:
-                x1, y1, x2, y2 = cell
+            # 检查单元格内是否有 OCR 文本框
+            has_ocr = False
+            for ocr in ocr_boxes:
+                ocr_bbox = ocr.get('bbox', [])
+                if not ocr_bbox or len(ocr_bbox) < 4:
+                    continue
+                
+                # 计算 OCR 中心点
+                if len(ocr_bbox) == 8:  # poly format
+                    ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
+                    ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
+                else:  # bbox format
+                    ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
+                    ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
+                
+                # 检查 OCR 中心点是否在当前单元格内
                 if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
-                    is_covered = True
+                    has_ocr = True
                     break
             
-            # 如果孤立,创建新单元格
-            if not is_covered:
-                # 扩展 OCR bbox 作为新单元格
-                if len(ocr_bbox) == 8:
-                    new_cell = [
-                        float(max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5)),
-                        float(max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5)),
-                        float(min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5)),
-                        float(min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5))
-                    ]
-                else:
-                    new_cell = [
-                        float(max(0, ocr_bbox[0] - 5)),
-                        float(max(0, ocr_bbox[1] - 5)),
-                        float(min(w, ocr_bbox[2] + 5)),
-                        float(min(h, ocr_bbox[3] + 5))
-                    ]
-                
-                compensated.append(new_cell)
-                compensated_labels.append('new')  # 标记为新增(OCR补偿)
-                compensation_count += 1
+            # 如果在边界且没有 OCR 文本,认为是噪声,过滤掉
+            if not has_ocr:
+                boundary_type = "left" if is_left_boundary else "right"
+                logger.debug(
+                    f"🗑️ 过滤边界噪声: {boundary_type} boundary cell "
+                    f"[{x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f}] (no OCR)"
+                )
+                filtered_count += 1
+            else:
+                # 有 OCR 文本,保留
+                filtered_cells.append(cell)
+                filtered_labels.append(label)
         
-        if compensation_count > 0:
-            logger.debug(f"OCR compensation: added {compensation_count} cells")
+        if filtered_count > 0:
+            logger.info(f"🗑️ Boundary noise filtering: removed {filtered_count} unet_only cells from boundaries")
         
-        return compensated, compensated_labels, compensation_count
+        return filtered_cells, filtered_labels, filtered_count
     
     def _visualize_fusion(
         self,