|
|
@@ -42,6 +42,7 @@ class CellFusionEngine:
|
|
|
- rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
|
|
|
- enable_ocr_compensation: True (启用OCR补偿)
|
|
|
- skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
|
|
|
+ - enable_boundary_noise_filter: True (启用边界噪声过滤)
|
|
|
"""
|
|
|
self.rtdetr_detector = rtdetr_detector
|
|
|
self.config = config or {}
|
|
|
@@ -54,10 +55,12 @@ class CellFusionEngine:
|
|
|
self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
|
self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
|
self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
|
|
|
+ self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
|
|
|
|
|
|
logger.info(f"🔧 CellFusionEngine initialized: "
|
|
|
f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
- f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}")
|
|
|
+ f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}, "
|
|
|
+ f"boundary_filter={self.enable_boundary_noise_filter}")
|
|
|
|
|
|
def should_use_rtdetr(
|
|
|
self,
|
|
|
@@ -101,7 +104,6 @@ class CellFusionEngine:
|
|
|
unet_cells: List[List[float]],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
pdf_type: str = 'ocr',
|
|
|
- upscale: float = 1.0,
|
|
|
debug_dir: Optional[str] = None,
|
|
|
debug_prefix: str = "fusion"
|
|
|
) -> Tuple[List[List[float]], Dict[str, Any]]:
|
|
|
@@ -113,7 +115,6 @@ class CellFusionEngine:
|
|
|
unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
|
|
|
ocr_boxes: OCR结果列表
|
|
|
pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
- upscale: UNet的上采样比例
|
|
|
debug_dir: 调试输出目录(可选)
|
|
|
debug_prefix: 调试文件前缀
|
|
|
|
|
|
@@ -123,6 +124,12 @@ class CellFusionEngine:
|
|
|
- fusion_stats: 融合统计信息
|
|
|
"""
|
|
|
h, w = table_image.shape[:2]
|
|
|
+ unet_bbox = [
|
|
|
+ min(unet_cells, key=lambda box: box[0])[0], \
|
|
|
+ min(unet_cells, key=lambda box: box[1])[1], \
|
|
|
+ max(unet_cells, key=lambda box: box[2])[2], \
|
|
|
+ max(unet_cells, key=lambda box: box[3])[3]
|
|
|
+ ] if unet_cells else [0,0,0,0]
|
|
|
|
|
|
# 决策:是否使用 RT-DETR
|
|
|
use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
|
|
|
@@ -143,14 +150,6 @@ class CellFusionEngine:
|
|
|
cell_labels = ['unet_only'] * len(fused_cells) # 所有都是UNet独有
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
|
- # 可选:OCR补偿
|
|
|
- if self.enable_ocr_compensation and ocr_boxes:
|
|
|
- fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
- fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
- )
|
|
|
- fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
- fusion_stats['fused_count'] = len(fused_cells)
|
|
|
-
|
|
|
logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
|
|
|
return fused_cells, fusion_stats
|
|
|
|
|
|
@@ -165,6 +164,12 @@ class CellFusionEngine:
|
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
|
+ rtdetr_bbox = [
|
|
|
+ min(rtdetr_cells, key=lambda box: box[0])[0],
|
|
|
+ min(rtdetr_cells, key=lambda box: box[1])[1],
|
|
|
+ max(rtdetr_cells, key=lambda box: box[2])[2],
|
|
|
+ max(rtdetr_cells, key=lambda box: box[3])[3]
|
|
|
+ ] if rtdetr_cells else [0,0,0,0]
|
|
|
|
|
|
logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
|
|
|
except Exception as e:
|
|
|
@@ -175,28 +180,33 @@ class CellFusionEngine:
|
|
|
|
|
|
# Phase 2: 智能融合
|
|
|
fused_cells, merge_stats, cell_labels = self._fuse_cells(
|
|
|
- unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
+ unet_bbox, unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
)
|
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
|
fusion_stats['added_count'] = merge_stats['added']
|
|
|
|
|
|
# Phase 3: NMS 去重
|
|
|
- fused_cells = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
|
-
|
|
|
- # Phase 4: OCR 补偿(可选)
|
|
|
- # if self.enable_ocr_compensation and ocr_boxes:
|
|
|
- # fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
- # fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
- # )
|
|
|
- # fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
+ fused_cells, suppressed = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
|
+ # 同步更新 cell_labels
|
|
|
+ cell_labels = [label for label, keep in zip(cell_labels, suppressed) if not keep]
|
|
|
+
|
|
|
+ # Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
|
|
|
+ if self.enable_boundary_noise_filter:
|
|
|
+ fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
|
|
|
+ fused_cells, cell_labels, ocr_boxes, rtdetr_bbox
|
|
|
+ )
|
|
|
+ fusion_stats['noise_filtered_count'] = noise_filtered
|
|
|
+ else:
|
|
|
+ fusion_stats['noise_filtered_count'] = 0
|
|
|
+ noise_filtered = 0
|
|
|
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
|
logger.info(
|
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
|
- f"Added={merge_stats['added']}, Final={len(fused_cells)}"
|
|
|
+ f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, Final={len(fused_cells)}"
|
|
|
)
|
|
|
|
|
|
# 可视化(调试)
|
|
|
@@ -210,6 +220,7 @@ class CellFusionEngine:
|
|
|
|
|
|
def _fuse_cells(
|
|
|
self,
|
|
|
+ unet_bbox: List[float],
|
|
|
unet_cells: List[List[float]],
|
|
|
rtdetr_cells: List[List[float]],
|
|
|
rtdetr_scores: List[float]
|
|
|
@@ -232,6 +243,7 @@ class CellFusionEngine:
|
|
|
- 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
|
|
|
|
|
|
Args:
|
|
|
+ unet_bbox: UNet单元格的边界框 [x1, y1, x2, y2]
|
|
|
unet_cells: UNet单元格列表
|
|
|
rtdetr_cells: RT-DETR单元格列表
|
|
|
rtdetr_scores: RT-DETR置信度列表
|
|
|
@@ -243,14 +255,6 @@ class CellFusionEngine:
|
|
|
- cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
|
|
|
"""
|
|
|
|
|
|
- # 计算unet_cells的边界框bbox[x1,y1,x2,y2]
|
|
|
- unet_bbox = [
|
|
|
- min(unet_cells, key=lambda box: box[0])[0], \
|
|
|
- min(unet_cells, key=lambda box: box[1])[1], \
|
|
|
- max(unet_cells, key=lambda box: box[2])[2], \
|
|
|
- max(unet_cells, key=lambda box: box[3])[3]
|
|
|
- ] if unet_cells else [0,0,0,0]
|
|
|
-
|
|
|
fused_cells = []
|
|
|
cell_labels = [] # 记录每个单元格的来源标签
|
|
|
unet_matched = [False] * len(unet_cells)
|
|
|
@@ -360,6 +364,7 @@ class CellFusionEngine:
|
|
|
if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
|
# 高IoU:加权平均合并
|
|
|
merged_cell = self._weighted_merge_bbox(
|
|
|
+ unet_bbox,
|
|
|
unet_cell,
|
|
|
rtdetr_cells[best_match_idx],
|
|
|
self.unet_weight,
|
|
|
@@ -396,6 +401,7 @@ class CellFusionEngine:
|
|
|
|
|
|
def _weighted_merge_bbox(
|
|
|
self,
|
|
|
+ table_bbox: List[float],
|
|
|
bbox1: List[float],
|
|
|
bbox2: List[float],
|
|
|
weight1: float,
|
|
|
@@ -405,6 +411,7 @@ class CellFusionEngine:
|
|
|
加权平均合并两个 bbox
|
|
|
|
|
|
Args:
|
|
|
+ table_bbox: 表格整体 bbox(用于限制合并结果)
|
|
|
bbox1: [x1, y1, x2, y2]
|
|
|
bbox2: [x1, y1, x2, y2]
|
|
|
weight1: bbox1 的权重
|
|
|
@@ -414,17 +421,17 @@ class CellFusionEngine:
|
|
|
merged_bbox: [x1, y1, x2, y2]
|
|
|
"""
|
|
|
return [
|
|
|
- weight1 * bbox1[0] + weight2 * bbox2[0],
|
|
|
- weight1 * bbox1[1] + weight2 * bbox2[1],
|
|
|
- weight1 * bbox1[2] + weight2 * bbox2[2],
|
|
|
- weight1 * bbox1[3] + weight2 * bbox2[3]
|
|
|
+ max(table_bbox[0], weight1 * bbox1[0] + weight2 * bbox2[0]),
|
|
|
+ max(table_bbox[1], weight1 * bbox1[1] + weight2 * bbox2[1]),
|
|
|
+ min(table_bbox[2], weight1 * bbox1[2] + weight2 * bbox2[2]),
|
|
|
+ min(table_bbox[3], weight1 * bbox1[3] + weight2 * bbox2[3])
|
|
|
]
|
|
|
|
|
|
def _nms_filter(
|
|
|
self,
|
|
|
cells: List[List[float]],
|
|
|
iou_threshold: float
|
|
|
- ) -> List[List[float]]:
|
|
|
+ ) -> Tuple[List[List[float]], List[bool]]:
|
|
|
"""
|
|
|
简单 NMS 过滤(去除高度重叠的冗余框)
|
|
|
|
|
|
@@ -436,9 +443,10 @@ class CellFusionEngine:
|
|
|
|
|
|
Returns:
|
|
|
过滤后的单元格列表
|
|
|
+ 抑制标记列表
|
|
|
"""
|
|
|
if len(cells) == 0:
|
|
|
- return []
|
|
|
+ return [], []
|
|
|
|
|
|
# 计算面积并排序(大框优先)
|
|
|
areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in cells]
|
|
|
@@ -463,81 +471,92 @@ class CellFusionEngine:
|
|
|
suppressed[other_idx] = True
|
|
|
|
|
|
logger.debug(f"NMS: {len(cells)} → {len(keep)} cells (threshold={iou_threshold})")
|
|
|
- return keep
|
|
|
+ return keep, suppressed
|
|
|
|
|
|
- def _compensate_with_ocr(
|
|
|
+ def _filter_boundary_noise(
|
|
|
self,
|
|
|
cells: List[List[float]],
|
|
|
cell_labels: List[str],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
- table_size: Tuple[int, int]
|
|
|
+ rtdetr_bbox: List[float]
|
|
|
) -> Tuple[List[List[float]], List[str], int]:
|
|
|
"""
|
|
|
- 使用 OCR 补偿遗漏的单元格
|
|
|
+ 过滤边界噪声单元格
|
|
|
|
|
|
- 策略:如果 OCR 文本没有匹配到任何单元格,创建新单元格
|
|
|
+ 过滤条件:
|
|
|
+ 1. 单元格标记为 'unet_only'(只在 UNet 中检测到,RT-DETR 未匹配)
|
|
|
+ 2. 单元格位于表格边界(左边界或右边界)
|
|
|
+ 3. 单元格内没有任何 OCR 文本框(说明是空白区域)
|
|
|
|
|
|
Args:
|
|
|
- cells: 现有单元格列表
|
|
|
+ cells: 单元格列表
|
|
|
cell_labels: 单元格标签列表
|
|
|
ocr_boxes: OCR结果列表
|
|
|
- table_size: 表格尺寸 (width, height)
|
|
|
-
|
|
|
+ rtdetr_bbox: RT-DETR单元格的边界框 [x1, y1, x2, y2]
|
|
|
Returns:
|
|
|
- (compensated_cells, compensated_labels, compensation_count)
|
|
|
+ (filtered_cells, filtered_labels, filtered_count)
|
|
|
"""
|
|
|
- compensated = cells.copy()
|
|
|
- compensated_labels = cell_labels.copy()
|
|
|
- compensation_count = 0
|
|
|
- w, h = table_size
|
|
|
-
|
|
|
- for ocr in ocr_boxes:
|
|
|
- ocr_bbox = ocr.get('bbox', [])
|
|
|
- if not ocr_bbox or len(ocr_bbox) < 4:
|
|
|
+ filtered_cells = []
|
|
|
+ filtered_labels = []
|
|
|
+ filtered_count = 0
|
|
|
+
|
|
|
+ for cell, label in zip(cells, cell_labels):
|
|
|
+ # # 只过滤 unet_only 标记的单元格
|
|
|
+ # if label != 'unet_only':
|
|
|
+ # filtered_cells.append(cell)
|
|
|
+ # filtered_labels.append(label)
|
|
|
+ # continue
|
|
|
+
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
+
|
|
|
+ # 检查是否在边界
|
|
|
+ is_left_boundary = x1 <= rtdetr_bbox[0]
|
|
|
+ is_right_boundary = x2 >= rtdetr_bbox[2]
|
|
|
+ is_on_boundary = is_left_boundary or is_right_boundary
|
|
|
+
|
|
|
+ if not is_on_boundary:
|
|
|
+ # 不在边界,保留
|
|
|
+ filtered_cells.append(cell)
|
|
|
+ filtered_labels.append(label)
|
|
|
continue
|
|
|
|
|
|
- # 计算 OCR 中心点
|
|
|
- if len(ocr_bbox) == 8: # poly format
|
|
|
- ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
|
|
|
- ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
|
|
|
- else: # bbox format
|
|
|
- ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
|
|
|
- ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
|
|
|
-
|
|
|
- # 检查是否在任何单元格内
|
|
|
- is_covered = False
|
|
|
- for cell in compensated:
|
|
|
- x1, y1, x2, y2 = cell
|
|
|
+ # 检查单元格内是否有 OCR 文本框
|
|
|
+ has_ocr = False
|
|
|
+ for ocr in ocr_boxes:
|
|
|
+ ocr_bbox = ocr.get('bbox', [])
|
|
|
+ if not ocr_bbox or len(ocr_bbox) < 4:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 计算 OCR 中心点
|
|
|
+ if len(ocr_bbox) == 8: # poly format
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
|
|
|
+ else: # bbox format
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
|
|
|
+
|
|
|
+ # 检查 OCR 中心点是否在当前单元格内
|
|
|
if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
|
|
|
- is_covered = True
|
|
|
+ has_ocr = True
|
|
|
break
|
|
|
|
|
|
- # 如果孤立,创建新单元格
|
|
|
- if not is_covered:
|
|
|
- # 扩展 OCR bbox 作为新单元格
|
|
|
- if len(ocr_bbox) == 8:
|
|
|
- new_cell = [
|
|
|
- float(max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5)),
|
|
|
- float(max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5)),
|
|
|
- float(min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5)),
|
|
|
- float(min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5))
|
|
|
- ]
|
|
|
- else:
|
|
|
- new_cell = [
|
|
|
- float(max(0, ocr_bbox[0] - 5)),
|
|
|
- float(max(0, ocr_bbox[1] - 5)),
|
|
|
- float(min(w, ocr_bbox[2] + 5)),
|
|
|
- float(min(h, ocr_bbox[3] + 5))
|
|
|
- ]
|
|
|
-
|
|
|
- compensated.append(new_cell)
|
|
|
- compensated_labels.append('new') # 标记为新增(OCR补偿)
|
|
|
- compensation_count += 1
|
|
|
+ # 如果在边界且没有 OCR 文本,认为是噪声,过滤掉
|
|
|
+ if not has_ocr:
|
|
|
+ boundary_type = "left" if is_left_boundary else "right"
|
|
|
+ logger.debug(
|
|
|
+ f"🗑️ 过滤边界噪声: {boundary_type} boundary cell "
|
|
|
+ f"[{x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f}] (no OCR)"
|
|
|
+ )
|
|
|
+ filtered_count += 1
|
|
|
+ else:
|
|
|
+ # 有 OCR 文本,保留
|
|
|
+ filtered_cells.append(cell)
|
|
|
+ filtered_labels.append(label)
|
|
|
|
|
|
- if compensation_count > 0:
|
|
|
- logger.debug(f"OCR compensation: added {compensation_count} cells")
|
|
|
+ if filtered_count > 0:
|
|
|
+ logger.info(f"🗑️ Boundary noise filtering: removed {filtered_count} unet_only cells from boundaries")
|
|
|
|
|
|
- return compensated, compensated_labels, compensation_count
|
|
|
+ return filtered_cells, filtered_labels, filtered_count
|
|
|
|
|
|
def _visualize_fusion(
|
|
|
self,
|