|
|
@@ -17,11 +17,14 @@ class CellFusionEngine:
|
|
|
融合策略:
|
|
|
1. UNet 连通域检测(结构性强,适合清晰表格)
|
|
|
2. RT-DETR 端到端检测(鲁棒性强,适合噪声表格)
|
|
|
- 3. OCR 文本位置(验证单元格存在性)
|
|
|
+ 3. OCR 边缘补偿(补偿"有OCR文本但无单元格覆盖"的位置)
|
|
|
|
|
|
- 自适应策略:
|
|
|
- - 文字PDF (pdf_type='txt'): 跳过 RT-DETR,纯 UNet 模式(无噪声)
|
|
|
- - 扫描PDF (pdf_type='ocr'): 启用融合模式(有噪声)
|
|
|
+ 处理流程:
|
|
|
+ - Phase 1: RT-DETR 检测
|
|
|
+ - Phase 2: UNet + RT-DETR 智能融合
|
|
|
+ - Phase 3: NMS 去重
|
|
|
+ - Phase 4: 边界噪声过滤
|
|
|
+ - Phase 5: OCR 边缘补偿
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
|
@@ -41,7 +44,6 @@ class CellFusionEngine:
|
|
|
- iou_nms_threshold: 0.5 (NMS去重阈值)
|
|
|
- rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
|
|
|
- enable_ocr_compensation: True (启用OCR补偿)
|
|
|
- - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
|
|
|
- enable_boundary_noise_filter: True (启用边界噪声过滤)
|
|
|
"""
|
|
|
self.rtdetr_detector = rtdetr_detector
|
|
|
@@ -54,12 +56,11 @@ class CellFusionEngine:
|
|
|
self.iou_nms_threshold = self.config.get('iou_nms_threshold', 0.5)
|
|
|
self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
|
self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
|
- self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
|
|
|
self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
|
|
|
|
|
|
logger.info(f"🔧 CellFusionEngine initialized: "
|
|
|
f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
- f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}, "
|
|
|
+ f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
|
|
|
f"boundary_filter={self.enable_boundary_noise_filter}")
|
|
|
|
|
|
def should_use_rtdetr(
|
|
|
@@ -79,23 +80,18 @@ class CellFusionEngine:
|
|
|
Returns:
|
|
|
是否使用 RT-DETR
|
|
|
"""
|
|
|
- # 策略1: 文字PDF跳过RT-DETR(无噪声,UNet结果已足够准确)
|
|
|
- if pdf_type == 'txt' and self.skip_rtdetr_for_txt_pdf:
|
|
|
- logger.debug(f"📄 Text PDF detected, skip RT-DETR (UNet cells: {unet_cell_count})")
|
|
|
- return False
|
|
|
-
|
|
|
- # 策略2: 如果 RT-DETR 检测器未初始化,跳过
|
|
|
+ # 策略1: 如果 RT-DETR 检测器未初始化,跳过
|
|
|
if self.rtdetr_detector is None:
|
|
|
logger.debug("⚠️ RT-DETR detector not initialized, skip fusion")
|
|
|
return False
|
|
|
|
|
|
- # 策略3: UNet检测结果为空,必须使用RT-DETR补救
|
|
|
+ # 策略2: UNet检测结果为空,必须使用RT-DETR补救
|
|
|
if unet_cell_count == 0:
|
|
|
logger.info("🚨 UNet detected 0 cells, force enable RT-DETR")
|
|
|
return True
|
|
|
|
|
|
- # 策略4: 扫描PDF,启用融合模式
|
|
|
- logger.debug(f"🔍 Scan PDF detected, enable RT-DETR fusion (UNet cells: {unet_cell_count})")
|
|
|
+ # 策略3: 启用融合模式(对所有PDF类型统一使用RT-DETR融合)
|
|
|
+ logger.debug(f"🔍 Enable RT-DETR fusion (UNet cells: {unet_cell_count})")
|
|
|
return True
|
|
|
|
|
|
def fuse(
|
|
|
@@ -144,10 +140,20 @@ class CellFusionEngine:
|
|
|
'ocr_compensated_count': 0
|
|
|
}
|
|
|
|
|
|
- # 如果不使用RT-DETR,直接返回UNet结果
|
|
|
+ # 如果不使用RT-DETR,返回UNet结果(但仍做OCR补偿)
|
|
|
if not use_rtdetr:
|
|
|
fused_cells = unet_cells.copy()
|
|
|
cell_labels = ['unet_only'] * len(fused_cells) # 所有都是UNet独有
|
|
|
+
|
|
|
+ # Phase 5: OCR 边缘补偿(RT-DETR不可用时也执行)
|
|
|
+ if self.enable_ocr_compensation and ocr_boxes:
|
|
|
+ compensated_cells, compensated_labels = self._compensate_edge_cells(
|
|
|
+ fused_cells, cell_labels, ocr_boxes
|
|
|
+ )
|
|
|
+ fused_cells.extend(compensated_cells)
|
|
|
+ cell_labels.extend(compensated_labels)
|
|
|
+ fusion_stats['ocr_compensated_count'] = len(compensated_cells)
|
|
|
+
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
|
logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
|
|
|
@@ -164,13 +170,6 @@ class CellFusionEngine:
|
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
|
- rtdetr_bbox = [
|
|
|
- min(rtdetr_cells, key=lambda box: box[0])[0],
|
|
|
- min(rtdetr_cells, key=lambda box: box[1])[1],
|
|
|
- max(rtdetr_cells, key=lambda box: box[2])[2],
|
|
|
- max(rtdetr_cells, key=lambda box: box[3])[3]
|
|
|
- ] if rtdetr_cells else [0,0,0,0]
|
|
|
-
|
|
|
logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
|
|
|
except Exception as e:
|
|
|
logger.warning(f"⚠️ RT-DETR detection failed: {e}, fallback to UNet-only")
|
|
|
@@ -179,8 +178,14 @@ class CellFusionEngine:
|
|
|
return fused_cells, fusion_stats
|
|
|
|
|
|
# Phase 2: 智能融合
|
|
|
+ # 使用稳健边界估计(避免单个超大单元格撑开边界)
|
|
|
+ table_bbox = self._estimate_robust_table_bbox(rtdetr_cells)
|
|
|
+
|
|
|
+ # 将所有单元格的边界限制在表格边界内
|
|
|
+ # rtdetr_cells = self._clip_cells_to_bbox(rtdetr_cells, table_bbox)
|
|
|
+
|
|
|
fused_cells, merge_stats, cell_labels = self._fuse_cells(
|
|
|
- unet_bbox, unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
+ table_bbox, unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
)
|
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
|
@@ -194,19 +199,29 @@ class CellFusionEngine:
|
|
|
# Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
|
|
|
if self.enable_boundary_noise_filter:
|
|
|
fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
|
|
|
- fused_cells, cell_labels, ocr_boxes, rtdetr_bbox
|
|
|
+ fused_cells, cell_labels, ocr_boxes, table_bbox
|
|
|
)
|
|
|
fusion_stats['noise_filtered_count'] = noise_filtered
|
|
|
else:
|
|
|
fusion_stats['noise_filtered_count'] = 0
|
|
|
noise_filtered = 0
|
|
|
|
|
|
+ # Phase 5: OCR 边缘补偿(补偿有 OCR 文本但无单元格覆盖的位置)
|
|
|
+ if self.enable_ocr_compensation and ocr_boxes:
|
|
|
+ compensated_cells, compensated_labels = self._compensate_edge_cells(
|
|
|
+ fused_cells, cell_labels, ocr_boxes
|
|
|
+ )
|
|
|
+ fused_cells.extend(compensated_cells)
|
|
|
+ cell_labels.extend(compensated_labels)
|
|
|
+ fusion_stats['ocr_compensated_count'] = len(compensated_cells)
|
|
|
+
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
|
logger.info(
|
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
|
- f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, Final={len(fused_cells)}"
|
|
|
+ f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
|
|
|
+ f"OCRCompensated={fusion_stats.get('ocr_compensated_count', 0)}, Final={len(fused_cells)}"
|
|
|
)
|
|
|
|
|
|
# 可视化(调试)
|
|
|
@@ -220,7 +235,7 @@ class CellFusionEngine:
|
|
|
|
|
|
def _fuse_cells(
|
|
|
self,
|
|
|
- unet_bbox: List[float],
|
|
|
+ table_bbox: List[float],
|
|
|
unet_cells: List[List[float]],
|
|
|
rtdetr_cells: List[List[float]],
|
|
|
rtdetr_scores: List[float]
|
|
|
@@ -243,7 +258,7 @@ class CellFusionEngine:
|
|
|
- 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
|
|
|
|
|
|
Args:
|
|
|
- unet_bbox: UNet单元格的边界框 [x1, y1, x2, y2]
|
|
|
+ table_bbox: 表格的边界框 [x1, y1, x2, y2]
|
|
|
unet_cells: UNet单元格列表
|
|
|
rtdetr_cells: RT-DETR单元格列表
|
|
|
rtdetr_scores: RT-DETR置信度列表
|
|
|
@@ -318,18 +333,18 @@ class CellFusionEngine:
|
|
|
|
|
|
# 如果覆盖率>50%,说明这是一个真实的合并单元格
|
|
|
if coverage > 0.5:
|
|
|
- # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过unet_bbox范围
|
|
|
+ # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
|
|
|
fused_cell = [
|
|
|
min(bounding_x1, rtdetr_cell[0]),
|
|
|
min(bounding_y1, rtdetr_cell[1]),
|
|
|
max(bounding_x2, rtdetr_cell[2]),
|
|
|
max(bounding_y2, rtdetr_cell[3])
|
|
|
]
|
|
|
- # x限制在unet_bbox范围内
|
|
|
- fused_cell[0] = max(fused_cell[0], unet_bbox[0])
|
|
|
- # fused_cell[1] = max(fused_cell[1], unet_bbox[1])
|
|
|
- fused_cell[2] = min(fused_cell[2], unet_bbox[2])
|
|
|
- # fused_cell[3] = min(fused_cell[3], unet_bbox[3])
|
|
|
+ # x限制在table_bbox范围内
|
|
|
+ fused_cell[0] = max(fused_cell[0], table_bbox[0])
|
|
|
+ # fused_cell[1] = max(fused_cell[1], table_bbox[1])
|
|
|
+ fused_cell[2] = min(fused_cell[2], table_bbox[2])
|
|
|
+ # fused_cell[3] = min(fused_cell[3], table_bbox[3])
|
|
|
fused_cells.append(fused_cell)
|
|
|
cell_labels.append('merged_span') # 标记为合并单元格
|
|
|
rtdetr_matched[rt_idx] = True
|
|
|
@@ -364,7 +379,7 @@ class CellFusionEngine:
|
|
|
if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
|
# 高IoU:加权平均合并
|
|
|
merged_cell = self._weighted_merge_bbox(
|
|
|
- unet_bbox,
|
|
|
+ table_bbox,
|
|
|
unet_cell,
|
|
|
rtdetr_cells[best_match_idx],
|
|
|
self.unet_weight,
|
|
|
@@ -384,17 +399,148 @@ class CellFusionEngine:
|
|
|
# Step 3: 补充 RT-DETR 独有的高置信度单元格
|
|
|
for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
|
- # rtdetr_cell不能超出unet_bbox范围, x方向分别限制
|
|
|
- rtdetr_cell[0] = max(rtdetr_cell[0], unet_bbox[0])
|
|
|
- # rtdetr_cell[1] = max(rtdetr_cell[1], unet_bbox[1])
|
|
|
- rtdetr_cell[2] = min(rtdetr_cell[2], unet_bbox[2])
|
|
|
- # rtdetr_cell[3] = min(rtdetr_cell[3], unet_bbox[3])
|
|
|
+ # rtdetr_cell不能超出table_bbox范围, x方向分别限制
|
|
|
+ rtdetr_cell[0] = max(rtdetr_cell[0], table_bbox[0])
|
|
|
+ # rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
|
|
|
+ rtdetr_cell[2] = min(rtdetr_cell[2], table_bbox[2])
|
|
|
+ # rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
|
|
|
fused_cells.append(rtdetr_cell)
|
|
|
cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
|
stats['added'] += 1
|
|
|
|
|
|
return fused_cells, stats, cell_labels
|
|
|
|
|
|
+ def _estimate_robust_table_bbox(
|
|
|
+ self,
|
|
|
+ rtdetr_cells: List[List[float]],
|
|
|
+ cluster_tolerance: float = 10.0
|
|
|
+ ) -> List[float]:
|
|
|
+ """
|
|
|
+ 稳健的表格边界估计
|
|
|
+
|
|
|
+ 使用聚类方法找到"主流"的左右边界,避免单个超大单元格撑开边界。
|
|
|
+
|
|
|
+ 算法:
|
|
|
+ 1. 收集所有单元格的左边界x1和右边界x2
|
|
|
+ 2. 对x1聚类,选择支持度最高的聚类中心作为表格左边界
|
|
|
+ 3. 对x2聚类,选择支持度最高的聚类中心作为表格右边界
|
|
|
+ 4. y方向使用简单的min/max(行高变化大,不适合聚类)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ rtdetr_cells: RT-DETR单元格列表
|
|
|
+ cluster_tolerance: 聚类容差(像素)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ table_bbox: [x1, y1, x2, y2]
|
|
|
+ """
|
|
|
+ all_cells = rtdetr_cells
|
|
|
+ if not all_cells:
|
|
|
+ return [0, 0, 0, 0]
|
|
|
+
|
|
|
+ # 收集所有边界坐标
|
|
|
+ x1_coords = [cell[0] for cell in all_cells]
|
|
|
+ x2_coords = [cell[2] for cell in all_cells]
|
|
|
+ y1_coords = [cell[1] for cell in all_cells]
|
|
|
+ y2_coords = [cell[3] for cell in all_cells]
|
|
|
+
|
|
|
+ # 对x1聚类,找主流左边界
|
|
|
+ robust_x1 = self._find_dominant_boundary(x1_coords, cluster_tolerance, mode='min')
|
|
|
+ # 对x2聚类,找主流右边界
|
|
|
+ robust_x2 = self._find_dominant_boundary(x2_coords, cluster_tolerance, mode='max')
|
|
|
+ # y方向直接取极值
|
|
|
+ robust_y1 = min(y1_coords)
|
|
|
+ robust_y2 = max(y2_coords)
|
|
|
+
|
|
|
+ logger.debug(f"📐 稳健边界估计: x=[{robust_x1:.1f}, {robust_x2:.1f}], "
|
|
|
+ f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]")
|
|
|
+
|
|
|
+ return [robust_x1, robust_y1, robust_x2, robust_y2]
|
|
|
+
|
|
|
+ def _find_dominant_boundary(
|
|
|
+ self,
|
|
|
+ coords: List[float],
|
|
|
+ tolerance: float,
|
|
|
+ mode: str = 'min'
|
|
|
+ ) -> float:
|
|
|
+ """
|
|
|
+ 找到边界方向上的稳健值
|
|
|
+
|
|
|
+ 算法:从边界方向开始,找到第一个支持度足够的聚类
|
|
|
+ - 对于左边界 (mode='min'):从最小值开始,排除孤立的异常小值
|
|
|
+ - 对于右边界 (mode='max'):从最大值开始,排除孤立的异常大值
|
|
|
+
|
|
|
+ Args:
|
|
|
+ coords: 坐标列表
|
|
|
+ tolerance: 聚类容差
|
|
|
+ mode: 'min' 表示找最小值方向的主流边界,'max' 表示找最大值方向
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 稳健的边界值
|
|
|
+ """
|
|
|
+ if not coords:
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ if len(coords) == 1:
|
|
|
+ return coords[0]
|
|
|
+
|
|
|
+ # 简单聚类
|
|
|
+ sorted_coords = sorted(coords)
|
|
|
+ clusters = []
|
|
|
+ curr_cluster = [sorted_coords[0]]
|
|
|
+
|
|
|
+ for x in sorted_coords[1:]:
|
|
|
+ if x - curr_cluster[-1] < tolerance:
|
|
|
+ curr_cluster.append(x)
|
|
|
+ else:
|
|
|
+ clusters.append(curr_cluster)
|
|
|
+ curr_cluster = [x]
|
|
|
+ clusters.append(curr_cluster)
|
|
|
+
|
|
|
+ # 根据边界方向找稳健值
|
|
|
+ min_support = 2 # 最小支持度阈值
|
|
|
+
|
|
|
+ if mode == 'min':
|
|
|
+ # 对于左边界:从最小值方向开始,找第一个支持度足够的聚类
|
|
|
+ # clusters 已按值从小到大排序
|
|
|
+ for cluster in clusters:
|
|
|
+ if len(cluster) >= min_support:
|
|
|
+ return sum(cluster) / len(cluster)
|
|
|
+ # 如果所有聚类支持度都不够,返回最小值
|
|
|
+ return sorted_coords[0]
|
|
|
+ else: # mode == 'max'
|
|
|
+ # 对于右边界:从最大值方向开始,找第一个支持度足够的聚类
|
|
|
+ for cluster in reversed(clusters):
|
|
|
+ if len(cluster) >= min_support:
|
|
|
+ return sum(cluster) / len(cluster)
|
|
|
+ # 如果所有聚类支持度都不够,返回最大值
|
|
|
+ return sorted_coords[-1]
|
|
|
+
|
|
|
+ def _clip_cells_to_bbox(
|
|
|
+ self,
|
|
|
+ cells: List[List[float]],
|
|
|
+ bbox: List[float]
|
|
|
+ ) -> List[List[float]]:
|
|
|
+ """
|
|
|
+ 将单元格边界裁剪到指定bbox范围内
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cells: 单元格列表
|
|
|
+ bbox: 限制边界 [x1, y1, x2, y2]
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 裁剪后的单元格列表
|
|
|
+ """
|
|
|
+ clipped = []
|
|
|
+ for cell in cells:
|
|
|
+ clipped_cell = [
|
|
|
+ max(cell[0], bbox[0]), # x1
|
|
|
+ cell[1], # y1 不裁剪(行高变化大)
|
|
|
+ min(cell[2], bbox[2]), # x2
|
|
|
+ cell[3] # y2 不裁剪
|
|
|
+ ]
|
|
|
+ clipped.append(clipped_cell)
|
|
|
+ return clipped
|
|
|
+
|
|
|
def _calc_bbox_area(self, bbox: List[float]) -> float:
|
|
|
"""计算bbox面积"""
|
|
|
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
|
|
@@ -558,6 +704,157 @@ class CellFusionEngine:
|
|
|
|
|
|
return filtered_cells, filtered_labels, filtered_count
|
|
|
|
|
|
+ def _compensate_edge_cells(
|
|
|
+ self,
|
|
|
+ fused_cells: List[List[float]],
|
|
|
+ cell_labels: List[str],
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ min_confidence: float = 0.7
|
|
|
+ ) -> Tuple[List[List[float]], List[str]]:
|
|
|
+ """
|
|
|
+ 简洁的 OCR 边缘补偿算法
|
|
|
+
|
|
|
+ 策略:仅补偿"有 OCR 文本但无单元格覆盖"的位置
|
|
|
+
|
|
|
+ 算法流程:
|
|
|
+ 1. 从融合后的单元格提取所有唯一的 x/y 坐标构建网格分割线
|
|
|
+ 2. 遍历 OCR 文本框,若其中心点不在任何现有单元格内
|
|
|
+ 3. 用网格分割线找到最近的网格边界,生成补偿单元格
|
|
|
+
|
|
|
+ Args:
|
|
|
+ fused_cells: 融合后的单元格列表 [[x1,y1,x2,y2], ...]
|
|
|
+ cell_labels: 单元格标签列表
|
|
|
+ ocr_boxes: OCR结果列表 [{'bbox': [x1,y1,x2,y2], 'text': str, 'confidence': float}, ...]
|
|
|
+ min_confidence: OCR最小置信度阈值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (compensated_cells, compensated_labels)
|
|
|
+ - compensated_cells: 补偿的单元格列表
|
|
|
+ - compensated_labels: 补偿单元格的标签列表(全为 'ocr_compensated')
|
|
|
+ """
|
|
|
+ if not fused_cells or not ocr_boxes:
|
|
|
+ return [], []
|
|
|
+
|
|
|
+ # Step 1: 从融合后单元格提取网格分割线
|
|
|
+ x_coords = set()
|
|
|
+ y_coords = set()
|
|
|
+ for cell in fused_cells:
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
+ x_coords.add(x1)
|
|
|
+ x_coords.add(x2)
|
|
|
+ y_coords.add(y1)
|
|
|
+ y_coords.add(y2)
|
|
|
+
|
|
|
+ x_dividers = sorted(x_coords)
|
|
|
+ y_dividers = sorted(y_coords)
|
|
|
+
|
|
|
+ if len(x_dividers) < 2 or len(y_dividers) < 2:
|
|
|
+ logger.debug("📊 OCR补偿: 网格分割线不足,跳过")
|
|
|
+ return [], []
|
|
|
+
|
|
|
+ logger.debug(f"📊 OCR补偿: 构建网格 {len(y_dividers)-1}行 × {len(x_dividers)-1}列")
|
|
|
+
|
|
|
+ # Step 2: 过滤有效 OCR
|
|
|
+ valid_ocr = [
|
|
|
+ ocr for ocr in ocr_boxes
|
|
|
+ if ocr.get('confidence', 1.0) >= min_confidence
|
|
|
+ and len(ocr.get('text', '').strip()) > 0
|
|
|
+ and len(ocr.get('bbox', [])) >= 4
|
|
|
+ ]
|
|
|
+
|
|
|
+ if not valid_ocr:
|
|
|
+ logger.debug("📊 OCR补偿: 无有效OCR")
|
|
|
+ return [], []
|
|
|
+
|
|
|
+ # Step 3: 检查每个 OCR 中心点是否在现有单元格内
|
|
|
+ compensated_cells = []
|
|
|
+ compensated_labels = []
|
|
|
+
|
|
|
+ for ocr in valid_ocr:
|
|
|
+ ocr_bbox = ocr['bbox']
|
|
|
+ ocr_text = ocr.get('text', '')[:20]
|
|
|
+
|
|
|
+ # 计算 OCR 中心点
|
|
|
+ if len(ocr_bbox) == 8: # poly format
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
|
|
|
+ else: # bbox format [x1, y1, x2, y2]
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
|
|
|
+
|
|
|
+ # 检查中心点是否在任何现有单元格内
|
|
|
+ is_covered = False
|
|
|
+ for cell in fused_cells:
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
+ if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
|
|
|
+ is_covered = True
|
|
|
+ break
|
|
|
+
|
|
|
+ # 也检查已补偿的单元格
|
|
|
+ if not is_covered:
|
|
|
+ for cell in compensated_cells:
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
+ if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
|
|
|
+ is_covered = True
|
|
|
+ break
|
|
|
+
|
|
|
+ if is_covered:
|
|
|
+ continue # 已被覆盖,跳过
|
|
|
+
|
|
|
+ # Step 4: 找到中心点所在的网格单元格
|
|
|
+ # 找到左边界(最大的 x <= ocr_cx)
|
|
|
+ cell_x1 = None
|
|
|
+ for x in x_dividers:
|
|
|
+ if x <= ocr_cx:
|
|
|
+ cell_x1 = x
|
|
|
+ else:
|
|
|
+ break
|
|
|
+
|
|
|
+ # 找到右边界(最小的 x > ocr_cx)
|
|
|
+ cell_x2 = None
|
|
|
+ for x in x_dividers:
|
|
|
+ if x > ocr_cx:
|
|
|
+ cell_x2 = x
|
|
|
+ break
|
|
|
+
|
|
|
+ # 找到上边界(最大的 y <= ocr_cy)
|
|
|
+ cell_y1 = None
|
|
|
+ for y in y_dividers:
|
|
|
+ if y <= ocr_cy:
|
|
|
+ cell_y1 = y
|
|
|
+ else:
|
|
|
+ break
|
|
|
+
|
|
|
+ # 找到下边界(最小的 y > ocr_cy)
|
|
|
+ cell_y2 = None
|
|
|
+ for y in y_dividers:
|
|
|
+ if y > ocr_cy:
|
|
|
+ cell_y2 = y
|
|
|
+ break
|
|
|
+
|
|
|
+ # 如果找不到完整的边界,跳过
|
|
|
+ if cell_x1 is None or cell_x2 is None or cell_y1 is None or cell_y2 is None:
|
|
|
+ logger.debug(f"⏭️ 跳过OCR '{ocr_text}': 中心点({ocr_cx:.1f},{ocr_cy:.1f})不在网格范围内")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 生成补偿单元格
|
|
|
+ compensated_bbox = [float(cell_x1), float(cell_y1), float(cell_x2), float(cell_y2)]
|
|
|
+ compensated_cells.append(compensated_bbox)
|
|
|
+ compensated_labels.append('ocr_compensated')
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ f"✅ OCR补偿: '{ocr_text}' | "
|
|
|
+ f"center=({ocr_cx:.1f},{ocr_cy:.1f}) → "
|
|
|
+ f"cell=[{cell_x1:.1f},{cell_y1:.1f},{cell_x2:.1f},{cell_y2:.1f}]"
|
|
|
+ )
|
|
|
+
|
|
|
+ if compensated_cells:
|
|
|
+ logger.info(f"🎉 OCR边缘补偿完成: +{len(compensated_cells)}个单元格")
|
|
|
+ else:
|
|
|
+ logger.debug("📊 OCR补偿: 所有OCR已被现有单元格覆盖")
|
|
|
+
|
|
|
+ return compensated_cells, compensated_labels
|
|
|
+
|
|
|
def _visualize_fusion(
|
|
|
self,
|
|
|
table_image: np.ndarray,
|
|
|
@@ -609,6 +906,7 @@ class CellFusionEngine:
|
|
|
merged_cells_1to1 = [] # 1:1融合单元格(黄色)
|
|
|
merged_cells_span = [] # 合并单元格(品红色,RT-DETR检测的跨格单元格)
|
|
|
new_cells = [] # 新增单元格(紫色)
|
|
|
+ ocr_compensated = [] # OCR补偿单元格(橙色)
|
|
|
|
|
|
for fused_cell, label in zip(fused_cells, cell_labels):
|
|
|
if label == 'unet_only':
|
|
|
@@ -621,6 +919,8 @@ class CellFusionEngine:
|
|
|
merged_cells_span.append(fused_cell)
|
|
|
elif label == 'new':
|
|
|
new_cells.append(fused_cell)
|
|
|
+ elif label == 'ocr_compensated':
|
|
|
+ ocr_compensated.append(fused_cell)
|
|
|
|
|
|
# 绘制不同类型的单元格
|
|
|
for cell in unet_only:
|
|
|
@@ -643,6 +943,10 @@ class CellFusionEngine:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2) # 紫色 - 新增
|
|
|
|
|
|
+ for cell in ocr_compensated:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 165, 255), 3) # 橙色 - OCR补偿(加粗)
|
|
|
+
|
|
|
# 添加图例
|
|
|
legend_y = 30
|
|
|
cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, legend_y),
|
|
|
@@ -663,6 +967,10 @@ class CellFusionEngine:
|
|
|
legend_y += 30
|
|
|
cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
|
|
|
+ if ocr_compensated:
|
|
|
+ legend_y += 30
|
|
|
+ cv2.putText(img3, f"Orange: OCR Compensated ({len(ocr_compensated)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 2)
|
|
|
|
|
|
# 拼接三栏对比
|
|
|
vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
|
|
|
@@ -675,7 +983,8 @@ class CellFusionEngine:
|
|
|
cv2.imwrite(str(output_path), vis_canvas)
|
|
|
logger.info(f"💾 融合可视化已保存: {output_path}")
|
|
|
logger.info(f" 📊 单元格分类: UNet独有={len(unet_only)}, RT-DETR独有={len(rtdetr_only)}, "
|
|
|
- f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, 新增={len(new_cells)}")
|
|
|
+ f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, "
|
|
|
+ f"新增={len(new_cells)}, OCR补偿={len(ocr_compensated)}")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Failed to visualize fusion: {e}")
|