Jelajahi Sumber

fix(cell_fusion): 更新合并单元格逻辑,增加unet_bbox边界框限制以提高准确性

zhch158_admin 2 minggu lalu
induk
melakukan
6db2bb35e7

+ 22 - 4
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -242,6 +242,15 @@ class CellFusionEngine:
             - stats: {'merged': int, 'added': int, 'merged_cells': int}
             - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
         """
+        
+        # 计算unet_cells的边界框bbox[x1,y1,x2,y2]
+        unet_bbox = [
+                    min(unet_cells, key=lambda box: box[0])[0], \
+                    min(unet_cells, key=lambda box: box[1])[1], \
+                    max(unet_cells, key=lambda box: box[2])[2], \
+                    max(unet_cells, key=lambda box: box[3])[3]
+                ] if unet_cells else [0,0,0,0]
+                    
         fused_cells = []
         cell_labels = []  # 记录每个单元格的来源标签
         unet_matched = [False] * len(unet_cells)
@@ -303,16 +312,20 @@ class CellFusionEngine:
                     rtdetr_area = self._calc_bbox_area(rtdetr_cell)
                     coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
                     
-                    # 如果覆盖率>40%,说明这是一个真实的合并单元格
-                    # 降低阈值从0.5到0.4,因为合并单元格可能包含很多空白区域
-                    if coverage > 0.4:
-                        # 认定为合并单元格,取bounding与RT-DETR的最大范围
+                    # 如果覆盖率>50%,说明这是一个真实的合并单元格
+                    if coverage > 0.5:
+                        # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过unet_bbox范围
                         fused_cell = [
                             min(bounding_x1, rtdetr_cell[0]),
                             min(bounding_y1, rtdetr_cell[1]),
                             max(bounding_x2, rtdetr_cell[2]),
                             max(bounding_y2, rtdetr_cell[3])
                         ]
+                        # x限制在unet_bbox范围内
+                        fused_cell[0] = max(fused_cell[0], unet_bbox[0])
+                        # fused_cell[1] = max(fused_cell[1], unet_bbox[1])
+                        fused_cell[2] = min(fused_cell[2], unet_bbox[2])
+                        # fused_cell[3] = min(fused_cell[3], unet_bbox[3])
                         fused_cells.append(fused_cell)
                         cell_labels.append('merged_span')  # 标记为合并单元格
                         rtdetr_matched[rt_idx] = True
@@ -366,6 +379,11 @@ class CellFusionEngine:
         # Step 3: 补充 RT-DETR 独有的高置信度单元格
         for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
             if not rtdetr_matched[idx] and score > 0.7:
+                # rtdetr_cell不能超出unet_bbox范围, x方向分别限制
+                rtdetr_cell[0] = max(rtdetr_cell[0], unet_bbox[0])
+                # rtdetr_cell[1] = max(rtdetr_cell[1], unet_bbox[1])
+                rtdetr_cell[2] = min(rtdetr_cell[2], unet_bbox[2])
+                # rtdetr_cell[3] = min(rtdetr_cell[3], unet_bbox[3])
                 fused_cells.append(rtdetr_cell)
                 cell_labels.append('rtdetr_only')  # 标记为RT-DETR独有
                 stats['added'] += 1