1
0

7 Revīzijas d602b87e58 ... 321e1fe2c7

Autors SHA1 Ziņojums Datums
  zhch158_admin 321e1fe2c7 feat(cell_fusion): 更新测试用例以启用RT-DETR融合,移除文字PDF跳过策略 2 nedēļas atpakaļ
  zhch158_admin dc9ae52f78 feat(MinerUWiredTableRecognizer): 更新单元格提取方法,简化OCR边缘补偿参数传递 2 nedēļas atpakaļ
  zhch158_admin 9f84fef765 feat(TextFiller): 添加动态置信度阈值计算方法,优化OCR结果处理 2 nedēļas atpakaļ
  zhch158_admin 7f414a2691 feat(table_recognition_wired): 更新OCR补偿功能描述,简化配置注释 2 nedēļas atpakaļ
  zhch158_admin 9bd8e9c502 refactor(GridRecovery): 移除冗余的边界线过滤方法,简化代码结构 2 nedēļas atpakaļ
  zhch158_admin a77c565c4a feat(cell_fusion): 更新融合流程,添加边界噪声过滤和OCR边缘补偿功能 2 nedēļas atpakaļ
  zhch158_admin f5099c51d9 feat(cell_fusion): 添加OCR边缘补偿功能,优化单元格融合质量 2 nedēļas atpakaļ

+ 18 - 15
docs/ocr_tools/universal_doc_parser/有线表格-多源单元格融合.md

@@ -721,7 +721,8 @@ def should_use_rtdetr(pdf_type, unet_cell_count, table_size):
   - RT-DETR 独有 + 高置信度 (>0.7) → 补充
   - UNet 独有 → 保留
 - **Phase 3**: NMS去重 (IoU>0.5)
-- **Phase 4**: OCR补偿
+- **Phase 4**: 边界噪声过滤
+- **Phase 5**: OCR边缘补偿(补偿"有OCR文本但无单元格覆盖"的位置)
 
 #### 3. **配置示例**
 ```yaml
@@ -732,23 +733,25 @@ wired_table_recognizer:
     unet_weight: 0.6
     rtdetr_weight: 0.4
     iou_merge_threshold: 0.7
-    skip_rtdetr_for_txt_pdf: true  # 🎯 文字PDF跳过RT-DETR
+    enable_ocr_compensation: true  # 启用OCR边缘补偿
 ```
 
 ### 📊 预期效果
 
-1. **文字PDF** (pdf_type='txt')
-   - 自动跳过 RT-DETR,纯 UNet 模式
-   - 性能提升:节省 100-200ms 推理时间
-   - 准确性:避免 RT-DETR 在无噪声图像上的误检
+1. **所有PDF类型**
+   - 统一使用 UNet + RT-DETR 融合模式
+   - OCR边缘补偿在融合后执行
+   - 边缘单元格召回率显著提升
 
-2. **扫描PDF** (pdf_type='ocr')
-   - 启用融合模式
-   - 鲁棒性提升 30%+(模糊/噪声表格)
-   - 边缘单元格召回率 +15%
+2. **融合流程**
+   - Phase 1: RT-DETR 检测
+   - Phase 2: UNet + RT-DETR 智能融合
+   - Phase 3: NMS 去重
+   - Phase 4: 边界噪声过滤
+   - Phase 5: OCR 边缘补偿
 
 3. **降级机制**
-   - RT-DETR模型未配置 → UNet-only
+   - RT-DETR模型未配置 → UNet-only + OCR补偿
    - RT-DETR推理失败 → 自动回退
    - UNet为空 → 强制启用RT-DETR
 
@@ -756,21 +759,21 @@ wired_table_recognizer:
 
 运行测试脚本:
 ```bash
-cd /Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/models/adapters/wired_table
+cd /Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests
 python test_cell_fusion.py
 ```
 
 测试覆盖:
 - ✅ RT-DETR检测器初始化
 - ✅ 融合引擎基本功能
-- ✅ 文字PDF自适应跳过
-- ✅ 扫描PDF启用融合
+- ✅ 所有PDF类型启用融合
+- ✅ OCR边缘补偿
 - ✅ 降级机制
 
 ### 📝 使用文档
 
 完整的配置和使用说明已包含在 cell_fusion_config_example.yaml 中。
 
-系统已完全实现并集成,支持自适应策略,文字PDF自动跳过RT-DETR检测以提升性能和准确性!🎉
+系统已完全实现并集成,所有PDF类型统一使用UNet+RT-DETR融合,OCR补偿在融合后执行!🎉
 
 Made changes.

+ 1 - 2
ocr_tools/universal_doc_parser/config/bank_statement_yusys_v3.yaml

@@ -93,8 +93,7 @@ table_recognition_wired:
     rtdetr_conf_threshold: 0.5  # RT-DETR置信度阈值
     
     # 功能开关
-    enable_ocr_compensation: true      # 启用OCR孤立文本补偿
-    skip_rtdetr_for_txt_pdf: true      # 🎯 文字PDF跳过RT-DETR(自适应策略)
+    enable_ocr_compensation: true      # 启用OCR边缘补偿
 
   # Debug 可视化配置(与 MinerUWiredTableRecognizer.DebugOptions 对齐)
   # 默认关闭。开启后将保存:表格线、连通域、逻辑网格结构、文本覆盖可视化。

+ 2 - 5
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -379,7 +379,7 @@ class MinerUWiredTableRecognizer:
             # Step 2: 使用连通域法提取单元格 (替换了原来的投影法)
             debug_prefix = f"{dbg.prefix}_grid" if dbg.prefix else "grid"
             
-            # 传入原图的实际尺寸、裁剪padding和OCR结果
+            # 传入原图的实际尺寸和裁剪padding
             bboxes = self.grid_recovery.compute_cells_from_lines(
                 hpred_up, 
                 vpred_up, 
@@ -389,14 +389,11 @@ class MinerUWiredTableRecognizer:
                 debug_dir=debug_dir,
                 debug_prefix=debug_prefix,
                 crop_padding=10,  # 传递 padding 值(与 element_processors.py 中的 crop_padding 保持一致)
-                ocr_bboxes=ocr_boxes,  # 🆕 传递OCR结果用于边缘补偿
-                enable_ocr_edge_compensation=True  # 🆕 启用OCR边缘补偿
             )
-            # bboxes = self.grid_recovery.compute_cells_from_lines(hpred_up, vpred_up, upscale) # Original call
             if not bboxes:
                 raise RuntimeError("未能提取出单元格")
 
-            # Step 2.3: 🆕 多源单元格融合(UNet + RT-DETR ), ocr边缘补偿在前面compute_cells_from_lines完成
+            # Step 2.3: 🆕 多源单元格融合(UNet + RT-DETR + OCR边缘补偿)
             fusion_stats = {}
             if self.cell_fusion_engine:
                 try:

+ 351 - 42
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -17,11 +17,14 @@ class CellFusionEngine:
     融合策略:
     1. UNet 连通域检测(结构性强,适合清晰表格)
     2. RT-DETR 端到端检测(鲁棒性强,适合噪声表格)
-    3. OCR 文本位置(验证单元格存在性
+    3. OCR 边缘补偿(补偿"有OCR文本但无单元格覆盖"的位置
     
-    自适应策略:
-    - 文字PDF (pdf_type='txt'): 跳过 RT-DETR,纯 UNet 模式(无噪声)
-    - 扫描PDF (pdf_type='ocr'): 启用融合模式(有噪声)
+    处理流程:
+    - Phase 1: RT-DETR 检测
+    - Phase 2: UNet + RT-DETR 智能融合
+    - Phase 3: NMS 去重
+    - Phase 4: 边界噪声过滤
+    - Phase 5: OCR 边缘补偿
     """
     
     def __init__(
@@ -41,7 +44,6 @@ class CellFusionEngine:
                 - iou_nms_threshold: 0.5 (NMS去重阈值)
                 - rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
                 - enable_ocr_compensation: True (启用OCR补偿)
-                - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
                 - enable_boundary_noise_filter: True (启用边界噪声过滤)
         """
         self.rtdetr_detector = rtdetr_detector
@@ -54,12 +56,11 @@ class CellFusionEngine:
         self.iou_nms_threshold = self.config.get('iou_nms_threshold', 0.5)
         self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
         self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
-        self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
         self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
         
         logger.info(f"🔧 CellFusionEngine initialized: "
                    f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
-                   f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}, "
+                   f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
                    f"boundary_filter={self.enable_boundary_noise_filter}")
     
     def should_use_rtdetr(
@@ -79,23 +80,18 @@ class CellFusionEngine:
         Returns:
             是否使用 RT-DETR
         """
-        # 策略1: 文字PDF跳过RT-DETR(无噪声,UNet结果已足够准确)
-        if pdf_type == 'txt' and self.skip_rtdetr_for_txt_pdf:
-            logger.debug(f"📄 Text PDF detected, skip RT-DETR (UNet cells: {unet_cell_count})")
-            return False
-        
-        # 策略2: 如果 RT-DETR 检测器未初始化,跳过
+        # 策略1: 如果 RT-DETR 检测器未初始化,跳过
         if self.rtdetr_detector is None:
             logger.debug("⚠️ RT-DETR detector not initialized, skip fusion")
             return False
         
-        # 策略3: UNet检测结果为空,必须使用RT-DETR补救
+        # 策略2: UNet检测结果为空,必须使用RT-DETR补救
         if unet_cell_count == 0:
             logger.info("🚨 UNet detected 0 cells, force enable RT-DETR")
             return True
         
-        # 策略4: 扫描PDF,启用融合模式
-        logger.debug(f"🔍 Scan PDF detected, enable RT-DETR fusion (UNet cells: {unet_cell_count})")
+        # 策略3: 启用融合模式(对所有PDF类型统一使用RT-DETR融合)
+        logger.debug(f"🔍 Enable RT-DETR fusion (UNet cells: {unet_cell_count})")
         return True
     
     def fuse(
@@ -144,10 +140,20 @@ class CellFusionEngine:
             'ocr_compensated_count': 0
         }
         
-        # 如果不使用RT-DETR,直接返回UNet结果
+        # 如果不使用RT-DETR,返回UNet结果(但仍做OCR补偿)
         if not use_rtdetr:
             fused_cells = unet_cells.copy()
             cell_labels = ['unet_only'] * len(fused_cells)  # 所有都是UNet独有
+            
+            # Phase 5: OCR 边缘补偿(RT-DETR不可用时也执行)
+            if self.enable_ocr_compensation and ocr_boxes:
+                compensated_cells, compensated_labels = self._compensate_edge_cells(
+                    fused_cells, cell_labels, ocr_boxes
+                )
+                fused_cells.extend(compensated_cells)
+                cell_labels.extend(compensated_labels)
+                fusion_stats['ocr_compensated_count'] = len(compensated_cells)
+            
             fusion_stats['fused_count'] = len(fused_cells)
             
             logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
@@ -164,13 +170,6 @@ class CellFusionEngine:
             rtdetr_cells = [res['bbox'] for res in rtdetr_results]
             rtdetr_scores = [res['score'] for res in rtdetr_results]
             fusion_stats['rtdetr_count'] = len(rtdetr_cells)
-            rtdetr_bbox = [
-                min(rtdetr_cells, key=lambda box: box[0])[0],
-                min(rtdetr_cells, key=lambda box: box[1])[1],
-                max(rtdetr_cells, key=lambda box: box[2])[2],
-                max(rtdetr_cells, key=lambda box: box[3])[3]
-            ] if rtdetr_cells else [0,0,0,0]
-            
             logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
         except Exception as e:
             logger.warning(f"⚠️ RT-DETR detection failed: {e}, fallback to UNet-only")
@@ -179,8 +178,14 @@ class CellFusionEngine:
             return fused_cells, fusion_stats
         
         # Phase 2: 智能融合
+        # 使用稳健边界估计(避免单个超大单元格撑开边界)
+        table_bbox = self._estimate_robust_table_bbox(rtdetr_cells)
+        
+        # 将所有单元格的边界限制在表格边界内
+        # rtdetr_cells = self._clip_cells_to_bbox(rtdetr_cells, table_bbox)
+        
         fused_cells, merge_stats, cell_labels = self._fuse_cells(
-            unet_bbox, unet_cells, rtdetr_cells, rtdetr_scores
+            table_bbox, unet_cells, rtdetr_cells, rtdetr_scores
         )
         fusion_stats['merged_count'] = merge_stats['merged']
         fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
@@ -194,19 +199,29 @@ class CellFusionEngine:
         # Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
         if self.enable_boundary_noise_filter:
             fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
-                fused_cells, cell_labels, ocr_boxes, rtdetr_bbox
+                fused_cells, cell_labels, ocr_boxes, table_bbox
             )
             fusion_stats['noise_filtered_count'] = noise_filtered
         else:
             fusion_stats['noise_filtered_count'] = 0
             noise_filtered = 0
         
+        # Phase 5: OCR 边缘补偿(补偿有 OCR 文本但无单元格覆盖的位置)
+        if self.enable_ocr_compensation and ocr_boxes:
+            compensated_cells, compensated_labels = self._compensate_edge_cells(
+                fused_cells, cell_labels, ocr_boxes
+            )
+            fused_cells.extend(compensated_cells)
+            cell_labels.extend(compensated_labels)
+            fusion_stats['ocr_compensated_count'] = len(compensated_cells)
+        
         fusion_stats['fused_count'] = len(fused_cells)
         
         logger.info(
             f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
             f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
-            f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, Final={len(fused_cells)}"
+            f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
+            f"OCRCompensated={fusion_stats.get('ocr_compensated_count', 0)}, Final={len(fused_cells)}"
         )
         
         # 可视化(调试)
@@ -220,7 +235,7 @@ class CellFusionEngine:
     
     def _fuse_cells(
         self,
-        unet_bbox: List[float],
+        table_bbox: List[float],
         unet_cells: List[List[float]],
         rtdetr_cells: List[List[float]],
         rtdetr_scores: List[float]
@@ -243,7 +258,7 @@ class CellFusionEngine:
         - 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
         
         Args:
-            unet_bbox: UNet单元格的边界框 [x1, y1, x2, y2]
+            table_bbox: 表格的边界框 [x1, y1, x2, y2]
             unet_cells: UNet单元格列表
             rtdetr_cells: RT-DETR单元格列表
             rtdetr_scores: RT-DETR置信度列表
@@ -318,18 +333,18 @@ class CellFusionEngine:
                     
                     # 如果覆盖率>50%,说明这是一个真实的合并单元格
                     if coverage > 0.5:
-                        # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过unet_bbox范围
+                        # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
                         fused_cell = [
                             min(bounding_x1, rtdetr_cell[0]),
                             min(bounding_y1, rtdetr_cell[1]),
                             max(bounding_x2, rtdetr_cell[2]),
                             max(bounding_y2, rtdetr_cell[3])
                         ]
-                        # x限制在unet_bbox范围内
-                        fused_cell[0] = max(fused_cell[0], unet_bbox[0])
-                        # fused_cell[1] = max(fused_cell[1], unet_bbox[1])
-                        fused_cell[2] = min(fused_cell[2], unet_bbox[2])
-                        # fused_cell[3] = min(fused_cell[3], unet_bbox[3])
+                        # x限制在table_bbox范围内
+                        fused_cell[0] = max(fused_cell[0], table_bbox[0])
+                        # fused_cell[1] = max(fused_cell[1], table_bbox[1])
+                        fused_cell[2] = min(fused_cell[2], table_bbox[2])
+                        # fused_cell[3] = min(fused_cell[3], table_bbox[3])
                         fused_cells.append(fused_cell)
                         cell_labels.append('merged_span')  # 标记为合并单元格
                         rtdetr_matched[rt_idx] = True
@@ -364,7 +379,7 @@ class CellFusionEngine:
             if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
                 # 高IoU:加权平均合并
                 merged_cell = self._weighted_merge_bbox(
-                    unet_bbox,
+                    table_bbox,
                     unet_cell,
                     rtdetr_cells[best_match_idx],
                     self.unet_weight,
@@ -384,17 +399,148 @@ class CellFusionEngine:
         # Step 3: 补充 RT-DETR 独有的高置信度单元格
         for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
             if not rtdetr_matched[idx] and score > 0.7:
-                # rtdetr_cell不能超出unet_bbox范围, x方向分别限制
-                rtdetr_cell[0] = max(rtdetr_cell[0], unet_bbox[0])
-                # rtdetr_cell[1] = max(rtdetr_cell[1], unet_bbox[1])
-                rtdetr_cell[2] = min(rtdetr_cell[2], unet_bbox[2])
-                # rtdetr_cell[3] = min(rtdetr_cell[3], unet_bbox[3])
+                # rtdetr_cell不能超出table_bbox范围, x方向分别限制
+                rtdetr_cell[0] = max(rtdetr_cell[0], table_bbox[0])
+                # rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
+                rtdetr_cell[2] = min(rtdetr_cell[2], table_bbox[2])
+                # rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
                 fused_cells.append(rtdetr_cell)
                 cell_labels.append('rtdetr_only')  # 标记为RT-DETR独有
                 stats['added'] += 1
         
         return fused_cells, stats, cell_labels
     
+    def _estimate_robust_table_bbox(
+        self,
+        rtdetr_cells: List[List[float]],
+        cluster_tolerance: float = 10.0
+    ) -> List[float]:
+        """
+        稳健的表格边界估计
+        
+        使用聚类方法找到"主流"的左右边界,避免单个超大单元格撑开边界。
+        
+        算法:
+        1. 收集所有单元格的左边界x1和右边界x2
+        2. 对x1聚类,选择支持度最高的聚类中心作为表格左边界
+        3. 对x2聚类,选择支持度最高的聚类中心作为表格右边界
+        4. y方向使用简单的min/max(行高变化大,不适合聚类)
+        
+        Args:
+            rtdetr_cells: RT-DETR单元格列表
+            cluster_tolerance: 聚类容差(像素)
+            
+        Returns:
+            table_bbox: [x1, y1, x2, y2]
+        """
+        all_cells = rtdetr_cells
+        if not all_cells:
+            return [0, 0, 0, 0]
+        
+        # 收集所有边界坐标
+        x1_coords = [cell[0] for cell in all_cells]
+        x2_coords = [cell[2] for cell in all_cells]
+        y1_coords = [cell[1] for cell in all_cells]
+        y2_coords = [cell[3] for cell in all_cells]
+        
+        # 对x1聚类,找主流左边界
+        robust_x1 = self._find_dominant_boundary(x1_coords, cluster_tolerance, mode='min')
+        # 对x2聚类,找主流右边界
+        robust_x2 = self._find_dominant_boundary(x2_coords, cluster_tolerance, mode='max')
+        # y方向直接取极值
+        robust_y1 = min(y1_coords)
+        robust_y2 = max(y2_coords)
+        
+        logger.debug(f"📐 稳健边界估计: x=[{robust_x1:.1f}, {robust_x2:.1f}], "
+                    f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]")
+        
+        return [robust_x1, robust_y1, robust_x2, robust_y2]
+    
+    def _find_dominant_boundary(
+        self,
+        coords: List[float],
+        tolerance: float,
+        mode: str = 'min'
+    ) -> float:
+        """
+        找到边界方向上的稳健值
+        
+        算法:从边界方向开始,找到第一个支持度足够的聚类
+        - 对于左边界 (mode='min'):从最小值开始,排除孤立的异常小值
+        - 对于右边界 (mode='max'):从最大值开始,排除孤立的异常大值
+        
+        Args:
+            coords: 坐标列表
+            tolerance: 聚类容差
+            mode: 'min' 表示找最小值方向的主流边界,'max' 表示找最大值方向
+            
+        Returns:
+            稳健的边界值
+        """
+        if not coords:
+            return 0.0
+        
+        if len(coords) == 1:
+            return coords[0]
+        
+        # 简单聚类
+        sorted_coords = sorted(coords)
+        clusters = []
+        curr_cluster = [sorted_coords[0]]
+        
+        for x in sorted_coords[1:]:
+            if x - curr_cluster[-1] < tolerance:
+                curr_cluster.append(x)
+            else:
+                clusters.append(curr_cluster)
+                curr_cluster = [x]
+        clusters.append(curr_cluster)
+        
+        # 根据边界方向找稳健值
+        min_support = 2  # 最小支持度阈值
+        
+        if mode == 'min':
+            # 对于左边界:从最小值方向开始,找第一个支持度足够的聚类
+            # clusters 已按值从小到大排序
+            for cluster in clusters:
+                if len(cluster) >= min_support:
+                    return sum(cluster) / len(cluster)
+            # 如果所有聚类支持度都不够,返回最小值
+            return sorted_coords[0]
+        else:  # mode == 'max'
+            # 对于右边界:从最大值方向开始,找第一个支持度足够的聚类
+            for cluster in reversed(clusters):
+                if len(cluster) >= min_support:
+                    return sum(cluster) / len(cluster)
+            # 如果所有聚类支持度都不够,返回最大值
+            return sorted_coords[-1]
+    
+    def _clip_cells_to_bbox(
+        self,
+        cells: List[List[float]],
+        bbox: List[float]
+    ) -> List[List[float]]:
+        """
+        将单元格边界裁剪到指定bbox范围内
+        
+        Args:
+            cells: 单元格列表
+            bbox: 限制边界 [x1, y1, x2, y2]
+            
+        Returns:
+            裁剪后的单元格列表
+        """
+        clipped = []
+        for cell in cells:
+            clipped_cell = [
+                max(cell[0], bbox[0]),  # x1
+                cell[1],                 # y1 不裁剪(行高变化大)
+                min(cell[2], bbox[2]),  # x2
+                cell[3]                  # y2 不裁剪
+            ]
+            clipped.append(clipped_cell)
+        return clipped
+    
     def _calc_bbox_area(self, bbox: List[float]) -> float:
         """计算bbox面积"""
         return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
@@ -558,6 +704,157 @@ class CellFusionEngine:
         
         return filtered_cells, filtered_labels, filtered_count
     
+    def _compensate_edge_cells(
+        self,
+        fused_cells: List[List[float]],
+        cell_labels: List[str],
+        ocr_boxes: List[Dict[str, Any]],
+        min_confidence: float = 0.7
+    ) -> Tuple[List[List[float]], List[str]]:
+        """
+        简洁的 OCR 边缘补偿算法
+        
+        策略:仅补偿"有 OCR 文本但无单元格覆盖"的位置
+        
+        算法流程:
+        1. 从融合后的单元格提取所有唯一的 x/y 坐标构建网格分割线
+        2. 遍历 OCR 文本框,若其中心点不在任何现有单元格内
+        3. 用网格分割线找到最近的网格边界,生成补偿单元格
+        
+        Args:
+            fused_cells: 融合后的单元格列表 [[x1,y1,x2,y2], ...]
+            cell_labels: 单元格标签列表
+            ocr_boxes: OCR结果列表 [{'bbox': [x1,y1,x2,y2], 'text': str, 'confidence': float}, ...]
+            min_confidence: OCR最小置信度阈值
+            
+        Returns:
+            (compensated_cells, compensated_labels)
+            - compensated_cells: 补偿的单元格列表
+            - compensated_labels: 补偿单元格的标签列表(全为 'ocr_compensated')
+        """
+        if not fused_cells or not ocr_boxes:
+            return [], []
+        
+        # Step 1: 从融合后单元格提取网格分割线
+        x_coords = set()
+        y_coords = set()
+        for cell in fused_cells:
+            x1, y1, x2, y2 = cell
+            x_coords.add(x1)
+            x_coords.add(x2)
+            y_coords.add(y1)
+            y_coords.add(y2)
+        
+        x_dividers = sorted(x_coords)
+        y_dividers = sorted(y_coords)
+        
+        if len(x_dividers) < 2 or len(y_dividers) < 2:
+            logger.debug("📊 OCR补偿: 网格分割线不足,跳过")
+            return [], []
+        
+        logger.debug(f"📊 OCR补偿: 构建网格 {len(y_dividers)-1}行 × {len(x_dividers)-1}列")
+        
+        # Step 2: 过滤有效 OCR
+        valid_ocr = [
+            ocr for ocr in ocr_boxes
+            if ocr.get('confidence', 1.0) >= min_confidence
+            and len(ocr.get('text', '').strip()) > 0
+            and len(ocr.get('bbox', [])) >= 4
+        ]
+        
+        if not valid_ocr:
+            logger.debug("📊 OCR补偿: 无有效OCR")
+            return [], []
+        
+        # Step 3: 检查每个 OCR 中心点是否在现有单元格内
+        compensated_cells = []
+        compensated_labels = []
+        
+        for ocr in valid_ocr:
+            ocr_bbox = ocr['bbox']
+            ocr_text = ocr.get('text', '')[:20]
+            
+            # 计算 OCR 中心点
+            if len(ocr_bbox) == 8:  # poly format
+                ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
+                ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
+            else:  # bbox format [x1, y1, x2, y2]
+                ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
+                ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
+            
+            # 检查中心点是否在任何现有单元格内
+            is_covered = False
+            for cell in fused_cells:
+                x1, y1, x2, y2 = cell
+                if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
+                    is_covered = True
+                    break
+            
+            # 也检查已补偿的单元格
+            if not is_covered:
+                for cell in compensated_cells:
+                    x1, y1, x2, y2 = cell
+                    if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
+                        is_covered = True
+                        break
+            
+            if is_covered:
+                continue  # 已被覆盖,跳过
+            
+            # Step 4: 找到中心点所在的网格单元格
+            # 找到左边界(最大的 x <= ocr_cx)
+            cell_x1 = None
+            for x in x_dividers:
+                if x <= ocr_cx:
+                    cell_x1 = x
+                else:
+                    break
+            
+            # 找到右边界(最小的 x > ocr_cx)
+            cell_x2 = None
+            for x in x_dividers:
+                if x > ocr_cx:
+                    cell_x2 = x
+                    break
+            
+            # 找到上边界(最大的 y <= ocr_cy)
+            cell_y1 = None
+            for y in y_dividers:
+                if y <= ocr_cy:
+                    cell_y1 = y
+                else:
+                    break
+            
+            # 找到下边界(最小的 y > ocr_cy)
+            cell_y2 = None
+            for y in y_dividers:
+                if y > ocr_cy:
+                    cell_y2 = y
+                    break
+            
+            # 如果找不到完整的边界,跳过
+            if cell_x1 is None or cell_x2 is None or cell_y1 is None or cell_y2 is None:
+                logger.debug(f"⏭️ 跳过OCR '{ocr_text}': 中心点({ocr_cx:.1f},{ocr_cy:.1f})不在网格范围内")
+                continue
+            
+            # 生成补偿单元格
+            compensated_bbox = [float(cell_x1), float(cell_y1), float(cell_x2), float(cell_y2)]
+            compensated_cells.append(compensated_bbox)
+            compensated_labels.append('ocr_compensated')
+            
+            logger.info(
+                f"✅ OCR补偿: '{ocr_text}' | "
+                f"center=({ocr_cx:.1f},{ocr_cy:.1f}) → "
+                f"cell=[{cell_x1:.1f},{cell_y1:.1f},{cell_x2:.1f},{cell_y2:.1f}]"
+            )
+        
+        if compensated_cells:
+            logger.info(f"🎉 OCR边缘补偿完成: +{len(compensated_cells)}个单元格")
+        else:
+            logger.debug("📊 OCR补偿: 所有OCR已被现有单元格覆盖")
+        
+        return compensated_cells, compensated_labels
+    
     def _visualize_fusion(
         self,
         table_image: np.ndarray,
@@ -609,6 +906,7 @@ class CellFusionEngine:
             merged_cells_1to1 = []  # 1:1融合单元格(黄色)
             merged_cells_span = []  # 合并单元格(品红色,RT-DETR检测的跨格单元格)
             new_cells = []  # 新增单元格(紫色)
+            ocr_compensated = []  # OCR补偿单元格(橙色)
             
             for fused_cell, label in zip(fused_cells, cell_labels):
                 if label == 'unet_only':
@@ -621,6 +919,8 @@ class CellFusionEngine:
                     merged_cells_span.append(fused_cell)
                 elif label == 'new':
                     new_cells.append(fused_cell)
+                elif label == 'ocr_compensated':
+                    ocr_compensated.append(fused_cell)
             
             # 绘制不同类型的单元格
             for cell in unet_only:
@@ -643,6 +943,10 @@ class CellFusionEngine:
                 x1, y1, x2, y2 = [int(v) for v in cell]
                 cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2)  # 紫色 - 新增
             
+            for cell in ocr_compensated:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 165, 255), 3)  # 橙色 - OCR补偿(加粗)
+            
             # 添加图例
             legend_y = 30
             cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, legend_y),
@@ -663,6 +967,10 @@ class CellFusionEngine:
                 legend_y += 30
                 cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
+            if ocr_compensated:
+                legend_y += 30
+                cv2.putText(img3, f"Orange: OCR Compensated ({len(ocr_compensated)})", (10, legend_y),
+                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 2)
             
             # 拼接三栏对比
             vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
@@ -675,7 +983,8 @@ class CellFusionEngine:
             cv2.imwrite(str(output_path), vis_canvas)
             logger.info(f"💾 融合可视化已保存: {output_path}")
             logger.info(f"   📊 单元格分类: UNet独有={len(unet_only)}, RT-DETR独有={len(rtdetr_only)}, "
-                       f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, 新增={len(new_cells)}")
+                       f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, "
+                       f"新增={len(new_cells)}, OCR补偿={len(ocr_compensated)}")
             
         except Exception as e:
             logger.warning(f"Failed to visualize fusion: {e}")

+ 0 - 523
ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py

@@ -147,25 +147,6 @@ class GridRecovery:
         return valid_lines
 
     @staticmethod
-    def _filter_lines_by_bboxes(lines, bboxes, is_horizontal, tolerance=5.0):
-        """过滤线条,只保留与bboxes边界对齐的线条"""
-        if not bboxes:
-            return lines
-        
-        if is_horizontal:
-            bbox_coords = {bbox[1] for bbox in bboxes} | {bbox[3] for bbox in bboxes}
-        else:
-            bbox_coords = {bbox[0] for bbox in bboxes} | {bbox[2] for bbox in bboxes}
-        
-        filtered_lines = []
-        for line in lines:
-            line_coord = (line[1] + line[3]) / 2 if is_horizontal else (line[0] + line[2]) / 2
-            if any(abs(line_coord - coord) < tolerance for coord in bbox_coords):
-                filtered_lines.append(line)
-        
-        return filtered_lines
-
-    @staticmethod
     def _save_debug_image(debug_dir, debug_prefix, step_name, img, is_lines=False, lines=None):
         """保存调试图片"""
         if not debug_dir:
@@ -198,8 +179,6 @@ class GridRecovery:
         debug_dir: Optional[str] = None,
         debug_prefix: str = "",
         crop_padding: int = 10,  # 新增:裁剪时的padding值(原图坐标系)
-        ocr_bboxes: Optional[List[Dict]] = None,  # 🆕 整页OCR结果
-        enable_ocr_edge_compensation: bool = True,  # 🆕 是否启用OCR边缘补偿
     ) -> List[List[float]]:
         """
         基于矢量重构的连通域分析 (Advanced Vector-based Recovery)
@@ -210,7 +189,6 @@ class GridRecovery:
         3. 线段归并/连接 (adjust_lines)
         4. 几何延长线段 (Custom final_adjust_lines with larger threshold)
         5. 重绘Mask并进行连通域分析
-        6. 🆕 OCR补偿未封闭的边缘单元格
         
         Args:
             hpred_up: 横线预测mask(上采样后)
@@ -221,14 +199,11 @@ class GridRecovery:
             debug_dir: 调试输出目录 (Optional)
             debug_prefix: 调试文件名前缀 (Optional)
             crop_padding: 裁剪时的padding值(原图坐标系,默认10px)
-            ocr_bboxes: 🆕 整页OCR结果 [{'bbox': [x1,y1,x2,y2], 'text': str, 'confidence': float}, ...]
-            enable_ocr_edge_compensation: 🆕 是否启用OCR边缘补偿(默认True)
             
         注意:
             - hpred_up/vpred_up 是上采样后的mask,坐标系已经放大了 upscale 倍
             - crop_padding 是原图坐标系的值,需要乘以 upscale 转换到mask坐标系
             - edge_margin 用于过滤贴近图像边缘的线条(padding区域的噪声)
-            - ocr_bboxes坐标应为原图坐标系,补偿算法会自动处理坐标转换
         
         Returns:
             单元格bbox列表 [[x1, y1, x2, y2], ...] (原图坐标系)
@@ -480,49 +455,6 @@ class GridRecovery:
         else:
             logger.info(f"矢量重构分析提取到 {len(bboxes)} 个单元格 (Dynamic Alpha: {dynamic_alpha}, upscale={upscale:.3f})")
         
-        # 🆕 Step 6: OCR补偿未封闭的边缘单元格
-        if enable_ocr_edge_compensation and ocr_bboxes and orig_h is not None and orig_w is not None:
-            logger.info("━━━━━━━━ 🔍 OCR边缘补偿 ━━━━━━━━")
-            
-            # 转换线条坐标到原图坐标系 (从mask坐标系转换)
-            rowboxes_orig = [
-                [line[0] / scale_w, line[1] / scale_h, line[2] / scale_w, line[3] / scale_h]
-                for line in rowboxes
-            ]
-            colboxes_orig = [
-                [line[0] / scale_w, line[1] / scale_h, line[2] / scale_w, line[3] / scale_h]
-                for line in colboxes
-            ]
-            
-            # 过滤线条:只保留与existing_bboxes边界对齐的线条
-            rowboxes_filtered = GridRecovery._filter_lines_by_bboxes(rowboxes_orig, bboxes, is_horizontal=True)
-            colboxes_filtered = GridRecovery._filter_lines_by_bboxes(colboxes_orig, bboxes, is_horizontal=False)
-            
-            logger.debug(
-                f"🔍 线条过滤: 横线 {len(rowboxes_orig)}→{len(rowboxes_filtered)}, "
-                f"竖线 {len(colboxes_orig)}→{len(colboxes_filtered)}"
-            )
-            
-            # 调用OCR补偿算法 (所有坐标均为原图坐标系)
-            compensated_bboxes = GridRecovery._compensate_unclosed_cells(
-                existing_bboxes=bboxes,  # 已有bbox (原图坐标系)
-                ocr_bboxes=ocr_bboxes,   # OCR结果 (原图坐标系)
-                rowboxes=rowboxes_filtered,  # 🆕 使用过滤后的水平线
-                colboxes=colboxes_filtered,  # 🆕 使用过滤后的垂直线
-                img_h=orig_h,
-                img_w=orig_w,
-                debug_dir=debug_dir,
-                debug_prefix=debug_prefix
-            )
-            
-            if compensated_bboxes:
-                logger.info(f"✅ OCR补偿成功: +{len(compensated_bboxes)}个边缘单元格")
-                bboxes.extend(compensated_bboxes)
-                # 重新排序
-                bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
-            else:
-                logger.info("ℹ️ OCR补偿: 无需补偿边缘单元格")
-        
         return bboxes
 
     @staticmethod
@@ -757,458 +689,3 @@ class GridRecovery:
             new_cells.append(new_cell)
 
         return new_cells
-    
-    @staticmethod
-    def _compensate_unclosed_cells(
-        existing_bboxes: List[List[float]],
-        ocr_bboxes: List[Dict],
-        rowboxes: List[List[float]],
-        colboxes: List[List[float]],
-        img_h: float,
-        img_w: float,
-        min_confidence: float = 0.7,
-        debug_dir: Optional[str] = None,
-        debug_prefix: str = ""
-    ) -> List[List[float]]:
-        """
-        基于网格矩阵补偿未封闭的边缘单元格
-        
-        新算法思路:
-        1. 从rowboxes/colboxes构建网格矩阵
-        2. 将existing_bboxes映射到网格单元
-        3. 检测空的边缘单元格(与已有单元格相邻)
-        4. 用OCR填充这些空单元格
-        
-        Args:
-            existing_bboxes: 连通域检测到的bbox列表 (原图坐标系)
-            ocr_bboxes: 整页OCR结果
-            rowboxes: 水平线列表 (原图坐标系)
-            colboxes: 垂直线列表 (原图坐标系)
-            img_h, img_w: 原图尺寸
-            min_confidence: OCR最小置信度阈值
-            debug_dir, debug_prefix: Debug可视化参数
-            
-        Returns:
-            补偿的bbox列表 (原图坐标系)
-        """
-        if not ocr_bboxes or not rowboxes or not colboxes:
-            logger.debug("📊 OCR补偿: 缺少必要数据")
-            return []
-        
-        logger.info(f"🔧 OCR补偿参数: img_size=({img_w:.0f}×{img_h:.0f})")
-        
-        # Step 1: 过滤OCR
-        valid_ocr = [
-            ocr for ocr in ocr_bboxes
-            if ocr.get('confidence', 1.0) >= min_confidence
-            and len(ocr.get('text', '').strip()) > 0
-        ]
-        
-        if not valid_ocr:
-            logger.debug(f"📊 OCR补偿: 过滤后无有效OCR")
-            return []
-        
-        # Step 2: 构建网格(使用线条中点作为分割线)
-        row_dividers = sorted(set((line[1] + line[3]) / 2 for line in rowboxes))
-        col_dividers = sorted(set((line[0] + line[2]) / 2 for line in colboxes))
-        
-        # 添加图像边界
-        if not row_dividers or row_dividers[0] > 5:
-            row_dividers.insert(0, 0.0)
-        if not row_dividers or row_dividers[-1] < img_h - 5:
-            row_dividers.append(img_h)
-        if not col_dividers or col_dividers[0] > 5:
-            col_dividers.insert(0, 0.0)
-        if not col_dividers or col_dividers[-1] < img_w - 5:
-            col_dividers.append(img_w)
-        
-        logger.debug(f"📊 网格: {len(row_dividers)-1}行 × {len(col_dividers)-1}列")
-        
-        # Step 3: 将existing_bboxes映射到网格单元(支持跨行跨列)
-        grid = {}  # {(row, col): True} - 标记已占用的单元格
-        
-        def find_overlapping_cells(bbox: List[float]) -> List[tuple]:
-            """找到bbox覆盖的所有网格单元[(row, col), ...]"""
-            x1, y1, x2, y2 = bbox
-            cells = []
-            
-            for i in range(len(row_dividers) - 1):
-                # 检查垂直方向重叠
-                grid_y1, grid_y2 = row_dividers[i], row_dividers[i + 1]
-                if max(y1, grid_y1) < min(y2, grid_y2):  # 有重叠
-                    for j in range(len(col_dividers) - 1):
-                        # 检查水平方向重叠
-                        grid_x1, grid_x2 = col_dividers[j], col_dividers[j + 1]
-                        if max(x1, grid_x1) < min(x2, grid_x2):  # 有重叠
-                            cells.append((i, j))
-            
-            return cells
-        
-        # 标记所有existing_bbox占用的网格单元
-        for bbox in existing_bboxes:
-            cells = find_overlapping_cells(bbox)
-            for cell in cells:
-                grid[cell] = True
-        
-        logger.debug(f"📊 已占用: {len(grid)}个网格单元 (共{(len(row_dividers)-1)*(len(col_dividers)-1)}个)")
-        
-        # Step 4: 迭代补偿 - 多轮查找有相邻单元格的OCR
-        # 第一轮补偿的OCR成为"已占用",让后续OCR能找到相邻单元格
-        ocr_to_empty_cells = {}  # {ocr_index: {'ocr', 'empty_cells'}}
-        remaining_ocr_indices = set(range(len(valid_ocr)))  # 剩余未处理的OCR索引
-        iteration = 0
-        max_iterations = 10  # 防止无限循环
-        
-        while remaining_ocr_indices and iteration < max_iterations:
-            iteration += 1
-            newly_added = {}  # 本轮新增的OCR
-            
-            for idx in list(remaining_ocr_indices):
-                ocr = valid_ocr[idx]
-                ocr_bbox = ocr['bbox']
-                ocr_text = ocr.get('text', '')[:30]
-                
-                # 🆕 使用OCR bbox的中心点查找所在单元格,避免跨多行/列的错误映射
-                ocr_center_x = (ocr_bbox[0] + ocr_bbox[2]) / 2
-                ocr_center_y = (ocr_bbox[1] + ocr_bbox[3]) / 2
-                
-                # 找到中心点所在的行和列
-                center_row = None
-                center_col = None
-                for i in range(len(row_dividers) - 1):
-                    if row_dividers[i] <= ocr_center_y < row_dividers[i + 1]:
-                        center_row = i
-                        break
-                for j in range(len(col_dividers) - 1):
-                    if col_dividers[j] <= ocr_center_x < col_dividers[j + 1]:
-                        center_col = j
-                        break
-                
-                if center_row is None or center_col is None:
-                    logger.debug(
-                        f"⏭️ 跳过OCR '{ocr_text}': 中心点({ocr_center_x:.1f},{ocr_center_y:.1f})不在网格内"
-                    )
-                    remaining_ocr_indices.remove(idx)
-                    continue
-                
-                # 检查中心点所在单元格是否为空
-                center_cell = (center_row, center_col)
-                if center_cell in grid:
-                    logger.debug(
-                        f"⏭️ 跳过OCR '{ocr_text}': 单元格[{center_row},{center_col}]已被占用"
-                    )
-                    remaining_ocr_indices.remove(idx)
-                    continue
-                
-                # 只使用中心点所在的单元格作为初始empty_cells
-                empty_cells = [center_cell]
-                
-                # 检查是否是边缘单元格(至少一个空单元格与已占用单元格相邻)
-                has_neighbor = False
-                for row, col in empty_cells:
-                    for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
-                        neighbor = (row + dr, col + dc)
-                        if neighbor in grid:
-                            has_neighbor = True
-                            break
-                    if has_neighbor:
-                        break
-                
-                if not has_neighbor:
-                    # 本轮没有相邻单元格,留到下一轮
-                    continue
-                
-                # 找到有相邻单元格的OCR,添加到本轮结果
-                newly_added[idx] = {
-                    'ocr': ocr,
-                    'empty_cells': empty_cells
-                }
-                remaining_ocr_indices.remove(idx)
-            
-            if not newly_added:
-                # 本轮没有新增OCR,终止迭代
-                logger.debug(f"📊 迭代终止: 第{iteration}轮无新增OCR")
-                break
-            
-            # 将本轮新增的OCR添加到总结果
-            ocr_to_empty_cells.update(newly_added)
-            
-            # 🆕 立即将本轮新增的OCR标记到grid,作为下一轮的"已占用单元格"
-            for idx, ocr_data in newly_added.items():
-                for cell in ocr_data['empty_cells']:
-                    grid[cell] = True
-            
-            logger.debug(
-                f"📊 第{iteration}轮: 新增{len(newly_added)}个OCR, "
-                f"剩余{len(remaining_ocr_indices)}个待处理"
-            )
-        
-        if remaining_ocr_indices:
-            logger.debug(
-                f"⏭️ {len(remaining_ocr_indices)}个OCR无法补偿(无相邻单元格或超出迭代次数)"
-            )
-        
-        logger.debug(f"📊 Step 4完成: {len(ocr_to_empty_cells)}个OCR需要补偿(共{iteration}轮迭代)")
-        
-        # Step 5: grid已在迭代过程中更新,跳过
-        # (不需要再次标记,因为每轮迭代都已经更新了grid)
-        
-        # Step 6: 去除边缘整行或整列的空网格(确定表格实际内容边界)
-        occupied_rows = set(r for r, c in grid.keys())
-        occupied_cols = set(c for r, c in grid.keys())
-        
-        if not occupied_rows or not occupied_cols:
-            logger.warning("⚠️ 没有占用的单元格,无法确定表格边界")
-            return []
-        
-        # 确定表格实际内容范围
-        content_min_row = min(occupied_rows)
-        content_max_row = max(occupied_rows)
-        content_min_col = min(occupied_cols)
-        content_max_col = max(occupied_cols)
-        
-        logger.debug(
-            f"📊 Step 6完成: 表格内容边界 = "
-            f"row[{content_min_row}-{content_max_row}] × col[{content_min_col}-{content_max_col}]"
-        )
-        
-        # 🆕 不恢复grid状态,保持OCR单元格的临时标记
-        # 这样在扩展时,每个OCR都能看到其他OCR的占用,避免重复扩展
-        
-        # Step 7: 对所有标记的OCR区域统一扩展(只能向表格内部扩展)
-        # 🆕 辅助函数:检查侧边相邻列/行的已占用单元格边界
-        def get_side_boundary_for_vertical_expansion(current_min_col, current_max_col, direction='up'):
-            """向上/下扩展时,检查左右两侧相邻列的单元格边界"""
-            boundary_rows = []
-            
-            # 检查左侧相邻列(current_min_col - 1)
-            if current_min_col > 0:
-                left_col = current_min_col - 1
-                occupied_rows_in_left = [r for r, c in grid.keys() if c == left_col]
-                if occupied_rows_in_left:
-                    if direction == 'up':
-                        boundary_rows.append(min(occupied_rows_in_left))
-                    else:  # down
-                        boundary_rows.append(max(occupied_rows_in_left))
-            
-            # 检查右侧相邻列(current_max_col + 1)
-            if current_max_col < len(col_dividers) - 2:
-                right_col = current_max_col + 1
-                occupied_rows_in_right = [r for r, c in grid.keys() if c == right_col]
-                if occupied_rows_in_right:
-                    if direction == 'up':
-                        boundary_rows.append(min(occupied_rows_in_right))
-                    else:  # down
-                        boundary_rows.append(max(occupied_rows_in_right))
-            
-            return boundary_rows
-        
-        def get_side_boundary_for_horizontal_expansion(current_min_row, current_max_row, direction='left'):
-            """向左/右扩展时,检查上下两侧相邻行的单元格边界"""
-            boundary_cols = []
-            
-            # 检查上侧相邻行(current_min_row - 1)
-            if current_min_row > 0:
-                top_row = current_min_row - 1
-                occupied_cols_in_top = [c for r, c in grid.keys() if r == top_row]
-                if occupied_cols_in_top:
-                    if direction == 'left':
-                        boundary_cols.append(min(occupied_cols_in_top))
-                    else:  # right
-                        boundary_cols.append(max(occupied_cols_in_top))
-            
-            # 检查下侧相邻行(current_max_row + 1)
-            if current_max_row < len(row_dividers) - 2:
-                bottom_row = current_max_row + 1
-                occupied_cols_in_bottom = [c for r, c in grid.keys() if r == bottom_row]
-                if occupied_cols_in_bottom:
-                    if direction == 'left':
-                        boundary_cols.append(min(occupied_cols_in_bottom))
-                    else:  # right
-                        boundary_cols.append(max(occupied_cols_in_bottom))
-            
-            return boundary_cols
-        
-        # 🆕 逐个处理每个OCR,扩展完立即更新grid状态
-        # 这样后续OCR能看到前面OCR已经扩展占据的单元格,避免重复扩展
-        for idx, ocr_data in ocr_to_empty_cells.items():
-            empty_cells = ocr_data['empty_cells']
-            ocr = ocr_data['ocr']
-            ocr_text = ocr.get('text', '')[:30]
-            
-            # 向上下左右扩展连续的空单元格(只能在表格内容边界内扩展)
-            expanded = set(empty_cells)
-            changed = True
-            while changed:
-                changed = False
-                current_min_row = min(r for r, c in expanded)
-                current_max_row = max(r for r, c in expanded)
-                current_min_col = min(c for r, c in expanded)
-                current_max_col = max(c for r, c in expanded)
-                
-                # 🆕 向上扩展(不能超过表格内容上边界 content_min_row)
-                if current_min_row > content_min_row:
-                    row_above = current_min_row - 1
-                    # 检查该行是否都是空的
-                    if all((row_above, col) not in grid for col in range(current_min_col, current_max_col + 1)):
-                        # 检查左右侧相邻列的单元格最小行(上边界)
-                        side_boundaries = get_side_boundary_for_vertical_expansion(
-                            current_min_col, current_max_col, 'up'
-                        )
-                        can_expand = True
-                        if side_boundaries:
-                            # 左右侧单元格的最小行,不能扩展超过它
-                            min_side_row = min(side_boundaries)
-                            if row_above < min_side_row:
-                                can_expand = False
-                        
-                        if can_expand:
-                            for col in range(current_min_col, current_max_col + 1):
-                                expanded.add((row_above, col))
-                            changed = True
-                
-                # 🆕 向下扩展(不能超过表格内容下边界 content_max_row)
-                if current_max_row < content_max_row:
-                    row_below = current_max_row + 1
-                    if all((row_below, col) not in grid for col in range(current_min_col, current_max_col + 1)):
-                        side_boundaries = get_side_boundary_for_vertical_expansion(
-                            current_min_col, current_max_col, 'down'
-                        )
-                        can_expand = True
-                        if side_boundaries:
-                            max_side_row = max(side_boundaries)
-                            if row_below > max_side_row:
-                                can_expand = False
-                        
-                        if can_expand:
-                            for col in range(current_min_col, current_max_col + 1):
-                                expanded.add((row_below, col))
-                            changed = True
-                
-                # 🆕 向左扩展(不能超过表格内容左边界 content_min_col)
-                if current_min_col > content_min_col:
-                    col_left = current_min_col - 1
-                    if all((row, col_left) not in grid for row in range(current_min_row, current_max_row + 1)):
-                        side_boundaries = get_side_boundary_for_horizontal_expansion(
-                            current_min_row, current_max_row, 'left'
-                        )
-                        can_expand = True
-                        if side_boundaries:
-                            min_side_col = min(side_boundaries)
-                            if col_left < min_side_col:
-                                can_expand = False
-                        
-                        if can_expand:
-                            for row in range(current_min_row, current_max_row + 1):
-                                expanded.add((row, col_left))
-                            changed = True
-                
-                # 🆕 向右扩展(不能超过表格内容右边界 content_max_col)
-                if current_max_col < content_max_col:
-                    col_right = current_max_col + 1
-                    if all((row, col_right) not in grid for row in range(current_min_row, current_max_row + 1)):
-                        side_boundaries = get_side_boundary_for_horizontal_expansion(
-                            current_min_row, current_max_row, 'right'
-                        )
-                        can_expand = True
-                        if side_boundaries:
-                            max_side_col = max(side_boundaries)
-                            if col_right > max_side_col:
-                                can_expand = False
-                        
-                        if can_expand:
-                            for row in range(current_min_row, current_max_row + 1):
-                                expanded.add((row, col_right))
-                            changed = True
-            
-            # 🆕 扩展完成后,立即将扩展后的单元格标记到grid中
-            # 这样后续OCR扩展时能看到这个OCR占据的区域,避免重复扩展
-            for cell in expanded:
-                grid[cell] = True
-            
-            # 更新扩展后的空单元格
-            ocr_to_empty_cells[idx]['expanded_cells'] = list(expanded)
-            logger.debug(f"  OCR '{ocr_text}' 扩展完成: {list(expanded)}")
-        
-        logger.debug(f"📊 Step 7完成: 所有OCR区域已扩展")
-        
-        # Step 8: 生成补偿bbox
-        compensated_bboxes = []
-        
-        for idx, ocr_data in ocr_to_empty_cells.items():
-            empty_cells = ocr_data['expanded_cells']
-            ocr = ocr_data['ocr']
-            ocr_text = ocr.get('text', '')[:30]
-            
-            # 找到所有空单元格的边界范围
-            min_row = min(r for r, c in empty_cells)
-            max_row = max(r for r, c in empty_cells)
-            min_col = min(c for r, c in empty_cells)
-            max_col = max(c for r, c in empty_cells)
-            
-            # 使用网格边界作为bbox(精确对齐)
-            # 显式转换为Python float,避免numpy.float32导致JSON序列化错误
-            y1 = float(row_dividers[min_row])
-            y2 = float(row_dividers[max_row + 1])
-            x1 = float(col_dividers[min_col])
-            x2 = float(col_dividers[max_col + 1])
-            
-            compensated_bbox = [x1, y1, x2, y2]
-            compensated_bboxes.append(compensated_bbox)
-            
-            # 标记这些单元格为已占用
-            for row, col in empty_cells:
-                grid[(row, col)] = True
-            
-            logger.info(
-                f"✅ 补偿单元格[{min_row}-{max_row},{min_col}-{max_col}]: '{ocr_text}' | "
-                f"bbox=[{x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f}] | "
-                f"占据{len(empty_cells)}个网格单元"
-            )
-        
-        # Step 5: Debug可视化(增强版:颜色区分原有/补偿单元格)
-        if debug_dir and compensated_bboxes:
-            try:
-                from pathlib import Path
-                vis_img = np.ones((int(img_h), int(img_w), 3), dtype=np.uint8) * 255
-                
-                # 绘制网格线(浅灰色虚线)
-                for y in row_dividers:
-                    cv2.line(vis_img, (0, int(y)), (int(img_w), int(y)), (220, 220, 220), 1, cv2.LINE_AA)
-                for x in col_dividers:
-                    cv2.line(vis_img, (int(x), 0), (int(x), int(img_h)), (220, 220, 220), 1, cv2.LINE_AA)
-                
-                # 绘制现有bbox(绿色 - 原有单元格)
-                for bbox in existing_bboxes:
-                    x1, y1, x2, y2 = [int(v) for v in bbox]
-                    cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 200, 0), 2)
-                
-                # 绘制补偿bbox(橙色 - 补偿单元格,加粗)
-                for bbox in compensated_bboxes:
-                    x1, y1, x2, y2 = [int(v) for v in bbox]
-                    cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 165, 255), 3)  # 橙色,线宽3
-                
-                # 添加图例和统计信息
-                legend_y = 30
-                cv2.putText(vis_img, f"OCR Compensation: +{len(compensated_bboxes)} cells", (10, legend_y),
-                           cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
-                legend_y += 35
-                cv2.putText(vis_img, f"Green: Original ({len(existing_bboxes)})", (10, legend_y),
-                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 200, 0), 2)
-                legend_y += 30
-                cv2.putText(vis_img, f"Orange: Compensated ({len(compensated_bboxes)})", (10, legend_y),
-                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 2)
-                legend_y += 30
-                cv2.putText(vis_img, f"Gray: Grid lines ({len(row_dividers)-1}x{len(col_dividers)-1})", (10, legend_y),
-                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, (150, 150, 150), 2)
-                
-                out_path = Path(debug_dir) / f"{debug_prefix}step06_ocr_compensation.png"
-                cv2.imwrite(str(out_path), vis_img)
-                logger.info(f"💾 OCR补偿可视化已保存: {out_path}")
-                logger.info(f"   📊 单元格统计: 原有={len(existing_bboxes)}, 补偿={len(compensated_bboxes)}, "
-                           f"总计={len(existing_bboxes) + len(compensated_bboxes)}")
-            except Exception as e:
-                logger.warning(f"⚠️ Debug可视化失败: {e}")
-        
-        logger.info(f"🎉 OCR补偿完成: +{len(compensated_bboxes)}个边缘单元格")
-        return compensated_bboxes

+ 52 - 6
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -28,7 +28,7 @@ class TextFiller:
         """
         self.ocr_engine = ocr_engine
         self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
-        self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.9)  # 单元格 OCR 置信度阈值
+        self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.9)  # 单元格 OCR 置信度阈值(基准值)
         
         # 跨单元格检测配置参数
         self.overlap_threshold_horizontal: float = config.get("overlap_threshold_horizontal", 0.2)
@@ -38,6 +38,45 @@ class TextFiller:
         self.other_cell_max_ratio: float = config.get("other_cell_max_ratio", 0.3)
     
     @staticmethod
+    def calculate_dynamic_confidence_threshold(text: str, base_threshold: float = 0.9) -> float:
+        """
+        根据文本长度动态计算置信度阈值
+        
+        策略:
+        - 单字符:使用较高阈值(避免误识别,如"1"误识别为"l")
+        - 短文本(2-3字符):使用中等阈值
+        - 中等长度(4-10字符):使用基准阈值
+        - 长文本(10+字符):使用较低阈值(长文本整体可靠性更高)
+        
+        Args:
+            text: 识别的文本
+            base_threshold: 基准置信度阈值(默认0.9)
+            
+        Returns:
+            动态调整后的置信度阈值
+        """
+        if not text:
+            return base_threshold
+        
+        text_len = len(text.strip())
+        
+        if text_len == 1:
+            # 单字符:提高阈值 +0.05
+            return min(0.95, base_threshold + 0.1)
+        elif text_len <= 3:
+            # 2-3字符:轻微提高阈值 +0.02
+            return min(0.92, base_threshold + 0.02)
+        elif text_len <= 10:
+            # 4-10字符:使用基准阈值
+            return max(0.85, base_threshold - 0.05)
+        elif text_len <= 20:
+            # 11-20字符:降低阈值 -0.03
+            return max(0.80, base_threshold - 0.1)
+        else:
+            # 20+字符:显著降低阈值 -0.05
+            return max(0.75, base_threshold - 0.15)
+    
+    @staticmethod
     def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
         """
         计算 OCR box 与单元格的重叠比例(重叠面积 / OCR box 面积)
@@ -608,7 +647,7 @@ class TextFiller:
 
             # 对齐长度,避免越界
             n = min(len(results) if isinstance(results, list) else 0, len(crop_list), len(crop_indices))
-            conf_th = self.ocr_conf_threshold
+            base_conf_th = self.ocr_conf_threshold
 
             # 辅助函数:清理文件名中的非法字符
             def sanitize_filename(text: str, max_length: int = 50) -> str:
@@ -642,10 +681,17 @@ class TextFiller:
                     except Exception as e:
                         logger.warning(f"保存单元格OCR图片失败 (cell {cell_idx}): {e}")
                 
-                if text_k and score_k >= conf_th:
-                    texts[cell_idx] = text_k
-                elif text_k:
-                    logger.debug(f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于阈值({conf_th}): (文本: '{text_k[:30]}...')")
+                if text_k:
+                    # 根据文本长度动态调整置信度阈值
+                    dynamic_conf_th = self.calculate_dynamic_confidence_threshold(text_k, base_conf_th)
+                    
+                    if score_k >= dynamic_conf_th:
+                        texts[cell_idx] = text_k
+                    else:
+                        logger.debug(
+                            f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于动态阈值({dynamic_conf_th:.2f}) "
+                            f"[文本长度={len(text_k)}, 基准阈值={base_conf_th:.2f}]: '{text_k[:30]}...'"
+                        )
 
         except Exception as e:
             logger.warning(f"二次OCR失败: {e}")

+ 3 - 4
ocr_tools/universal_doc_parser/tests/cell_fusion_config_example.yaml

@@ -24,8 +24,7 @@ wired_table_recognizer:
     rtdetr_conf_threshold: 0.5  # RT-DETR置信度阈值
     
     # 功能开关
-    enable_ocr_compensation: true      # 启用OCR孤立文本补偿
-    skip_rtdetr_for_txt_pdf: true      # 🎯 文字PDF跳过RT-DETR(自适应策略)
+    enable_ocr_compensation: true      # 启用OCR边缘补偿
   
   # 调试选项
   debug_options:
@@ -35,8 +34,8 @@ wired_table_recognizer:
     save_fusion_comparison: true  # 保存融合对比图
 
 # 使用说明:
-# 1. 文字PDF (pdf_type='txt'): 自动跳过RT-DETR,使用纯UNet模式(无噪声干扰)
-# 2. 扫描PDF (pdf_type='ocr'): 启用融合模式,结合UNet、RT-DETR和OCR三路结果
+# 1. 所有PDF类型都使用UNet+RT-DETR融合模式
+# 2. OCR边缘补偿在融合后执行,补偿"有OCR文本但无单元格覆盖"的位置
 # 3. UNet结果为空: 强制启用RT-DETR补救
 # 4. 融合失败: 自动降级到UNet-only模式
 

+ 18 - 21
ocr_tools/universal_doc_parser/tests/test_cell_fusion.py

@@ -71,8 +71,7 @@ def test_fusion_engine(detector):
         'iou_merge_threshold': 0.7,
         'iou_nms_threshold': 0.5,
         'rtdetr_conf_threshold': 0.5,
-        'enable_ocr_compensation': True,
-        'skip_rtdetr_for_txt_pdf': True
+        'enable_ocr_compensation': True
     }
     
     # 初始化
@@ -93,33 +92,30 @@ def test_fusion_engine(detector):
         {'bbox': [20, 70, 80, 90], 'text': 'Cell 2'}
     ]
     
-    # Test 2.1: 文字PDF模式(应跳过RT-DETR
-    print("\n📄 Test 2.1: Text PDF mode (should skip RT-DETR)")
+    # Test 2.1: 文字PDF模式(现在也使用RT-DETR融合
+    print("\n📄 Test 2.1: Text PDF mode (now uses RT-DETR fusion)")
     fused_cells, stats = engine.fuse(
         table_image=table_image,
         unet_cells=unet_cells,
         ocr_boxes=ocr_boxes,
-        pdf_type='txt',
-        upscale=1.0
+        pdf_type='txt'
     )
     print(f"   Use RT-DETR: {stats['use_rtdetr']}")
     print(f"   Fused cells: {len(fused_cells)}")
-    assert not stats['use_rtdetr'], "❌ Should skip RT-DETR for text PDF"
-    assert len(fused_cells) == len(unet_cells), "❌ Should keep UNet cells only"
-    print("   ✅ Correctly skipped RT-DETR for text PDF")
+    assert stats['use_rtdetr'], "✔️ Now uses RT-DETR for all PDF types"
+    print("   ✅ Correctly enabled RT-DETR for text PDF")
     
-    # Test 2.2: 扫描PDF模式(应启用RT-DETR,但因为是假图片可能失败)
-    print("\n🔍 Test 2.2: Scan PDF mode (should enable RT-DETR)")
+    # Test 2.2: 扫描PDF模式
+    print("\n🔍 Test 2.2: Scan PDF mode")
     fused_cells, stats = engine.fuse(
         table_image=table_image,
         unet_cells=unet_cells,
         ocr_boxes=ocr_boxes,
-        pdf_type='ocr',
-        upscale=1.0
+        pdf_type='ocr'
     )
     print(f"   Use RT-DETR: {stats['use_rtdetr']}")
     print(f"   Stats: {stats}")
-    print("   ✅ Fusion completed (RT-DETR may return 0 cells on blank image)")
+    print("   ✅ Fusion completed")
     
     return engine
 
@@ -130,24 +126,25 @@ def test_adaptive_strategy():
     print("Test 3: 自适应策略测试")
     print("=" * 60)
     
-    engine = CellFusionEngine(rtdetr_detector=None, config={'skip_rtdetr_for_txt_pdf': True})
+    engine = CellFusionEngine(rtdetr_detector=None, config={})
     
-    # Test 3.1: 文字PDF + 正常单元格数 → 跳过
+    # Test 3.1: 文字PDF + 检测器未初始化 → 跳过
     should_use = engine.should_use_rtdetr('txt', unet_cell_count=10, table_size=(500, 500))
-    print(f"📄 Text PDF, 10 cells: use_rtdetr={should_use}")
-    assert not should_use, "❌ Should skip RT-DETR"
+    print(f"📄 Text PDF, 10 cells, no detector: use_rtdetr={should_use}")
+    assert not should_use, "❌ Should skip (detector not available)"
     print("   ✅ Correct")
     
-    # Test 3.2: 扫描PDF + 正常单元格数 → 跳过(因为检测器未初始化)
+    # Test 3.2: 扫描PDF + 检测器未初始化 → 跳过
     should_use = engine.should_use_rtdetr('ocr', unet_cell_count=10, table_size=(500, 500))
     print(f"🔍 Scan PDF, 10 cells, no detector: use_rtdetr={should_use}")
     assert not should_use, "❌ Should skip (detector not available)"
     print("   ✅ Correct")
     
-    # Test 3.3: UNet为空 → 强制启用(但检测器未初始化,仍跳过)
+    # Test 3.3: UNet为空 + 检测器未初始化 → 仍跳过
     should_use = engine.should_use_rtdetr('ocr', unet_cell_count=0, table_size=(500, 500))
     print(f"🚨 Scan PDF, 0 cells, no detector: use_rtdetr={should_use}")
-    print("   ⚠️ Would force enable if detector available")
+    assert not should_use, "❌ Should skip (detector not available)"
+    print("   ✅ Correct (would force enable if detector available)")
     
     print("\n✅ All adaptive strategy tests passed")