|
|
@@ -0,0 +1,644 @@
|
|
|
+"""多源单元格融合引擎:融合 UNet、RT-DETR 和 OCR 结果"""
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+from typing import Dict, List, Tuple, Optional, Any
|
|
|
+from loguru import logger
|
|
|
+
|
|
|
+try:
|
|
|
+ from ocr_utils.coordinate_utils import CoordinateUtils
|
|
|
+except ImportError:
|
|
|
+ from ...core.coordinate_utils import CoordinateUtils
|
|
|
+
|
|
|
+
|
|
|
+class CellFusionEngine:
|
|
|
+ """
|
|
|
+ 多源单元格融合引擎
|
|
|
+
|
|
|
+ 融合策略:
|
|
|
+ 1. UNet 连通域检测(结构性强,适合清晰表格)
|
|
|
+ 2. RT-DETR 端到端检测(鲁棒性强,适合噪声表格)
|
|
|
+ 3. OCR 文本位置(验证单元格存在性)
|
|
|
+
|
|
|
+ 自适应策略:
|
|
|
+ - 文字PDF (pdf_type='txt'): 跳过 RT-DETR,纯 UNet 模式(无噪声)
|
|
|
+ - 扫描PDF (pdf_type='ocr'): 启用融合模式(有噪声)
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ rtdetr_detector: Optional[Any] = None,
|
|
|
+ config: Optional[Dict[str, Any]] = None
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 初始化融合引擎
|
|
|
+
|
|
|
+ Args:
|
|
|
+ rtdetr_detector: RT-DETR 检测器实例(可选)
|
|
|
+ config: 融合配置
|
|
|
+ - unet_weight: 0.6 (UNet 权重)
|
|
|
+ - rtdetr_weight: 0.4 (RT-DETR 权重)
|
|
|
+ - iou_merge_threshold: 0.7 (高IoU合并阈值)
|
|
|
+ - iou_nms_threshold: 0.5 (NMS去重阈值)
|
|
|
+ - rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
|
|
|
+ - enable_ocr_compensation: True (启用OCR补偿)
|
|
|
+ - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
|
|
|
+ """
|
|
|
+ self.rtdetr_detector = rtdetr_detector
|
|
|
+ self.config = config or {}
|
|
|
+
|
|
|
+ # 融合参数
|
|
|
+ self.unet_weight = self.config.get('unet_weight', 0.6)
|
|
|
+ self.rtdetr_weight = self.config.get('rtdetr_weight', 0.4)
|
|
|
+ self.iou_merge_threshold = self.config.get('iou_merge_threshold', 0.7)
|
|
|
+ self.iou_nms_threshold = self.config.get('iou_nms_threshold', 0.5)
|
|
|
+ self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
|
+ self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
|
+ self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
|
|
|
+
|
|
|
+ logger.info(f"🔧 CellFusionEngine initialized: "
|
|
|
+ f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
+ f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}")
|
|
|
+
|
|
|
+ def should_use_rtdetr(
|
|
|
+ self,
|
|
|
+ pdf_type: str,
|
|
|
+ unet_cell_count: int,
|
|
|
+ table_size: Tuple[int, int]
|
|
|
+ ) -> bool:
|
|
|
+ """
|
|
|
+ 判断是否需要使用 RT-DETR 检测(自适应策略)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
+ unet_cell_count: UNet检测到的单元格数量
|
|
|
+ table_size: 表格尺寸 (width, height)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 是否使用 RT-DETR
|
|
|
+ """
|
|
|
+ # 策略1: 文字PDF跳过RT-DETR(无噪声,UNet结果已足够准确)
|
|
|
+ if pdf_type == 'txt' and self.skip_rtdetr_for_txt_pdf:
|
|
|
+ logger.debug(f"📄 Text PDF detected, skip RT-DETR (UNet cells: {unet_cell_count})")
|
|
|
+ return False
|
|
|
+
|
|
|
+ # 策略2: 如果 RT-DETR 检测器未初始化,跳过
|
|
|
+ if self.rtdetr_detector is None:
|
|
|
+ logger.debug("⚠️ RT-DETR detector not initialized, skip fusion")
|
|
|
+ return False
|
|
|
+
|
|
|
+ # 策略3: UNet检测结果为空,必须使用RT-DETR补救
|
|
|
+ if unet_cell_count == 0:
|
|
|
+ logger.info("🚨 UNet detected 0 cells, force enable RT-DETR")
|
|
|
+ return True
|
|
|
+
|
|
|
+ # 策略4: 扫描PDF,启用融合模式
|
|
|
+ logger.debug(f"🔍 Scan PDF detected, enable RT-DETR fusion (UNet cells: {unet_cell_count})")
|
|
|
+ return True
|
|
|
+
|
|
|
+ def fuse(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ unet_cells: List[List[float]],
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ pdf_type: str = 'ocr',
|
|
|
+ upscale: float = 1.0,
|
|
|
+ debug_dir: Optional[str] = None,
|
|
|
+ debug_prefix: str = "fusion"
|
|
|
+ ) -> Tuple[List[List[float]], Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 融合多源单元格检测结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table_image: 表格图像(原图坐标系)
|
|
|
+ unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
+ pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
+ upscale: UNet的上采样比例
|
|
|
+ debug_dir: 调试输出目录(可选)
|
|
|
+ debug_prefix: 调试文件前缀
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (fused_cells, fusion_stats)
|
|
|
+ - fused_cells: 融合后的单元格列表 [[x1,y1,x2,y2], ...]
|
|
|
+ - fusion_stats: 融合统计信息
|
|
|
+ """
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
+
|
|
|
+ # 决策:是否使用 RT-DETR
|
|
|
+ use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
|
|
|
+
|
|
|
+ fusion_stats = {
|
|
|
+ 'use_rtdetr': use_rtdetr,
|
|
|
+ 'unet_count': len(unet_cells),
|
|
|
+ 'rtdetr_count': 0,
|
|
|
+ 'fused_count': 0,
|
|
|
+ 'merged_count': 0,
|
|
|
+ 'added_count': 0,
|
|
|
+ 'ocr_compensated_count': 0
|
|
|
+ }
|
|
|
+
|
|
|
+ # 如果不使用RT-DETR,直接返回UNet结果
|
|
|
+ if not use_rtdetr:
|
|
|
+ fused_cells = unet_cells.copy()
|
|
|
+ cell_labels = ['unet_only'] * len(fused_cells) # 所有都是UNet独有
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
+
|
|
|
+ # 可选:OCR补偿
|
|
|
+ if self.enable_ocr_compensation and ocr_boxes:
|
|
|
+ fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
+ fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
+ )
|
|
|
+ fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
+
|
|
|
+ logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
|
|
|
+ return fused_cells, fusion_stats
|
|
|
+
|
|
|
+ # Phase 1: RT-DETR 检测
|
|
|
+ try:
|
|
|
+ rtdetr_results = self.rtdetr_detector.detect(
|
|
|
+ table_image,
|
|
|
+ conf_threshold=self.rtdetr_conf_threshold
|
|
|
+ )
|
|
|
+ # rtdetr_result从上到下,从左到右排序
|
|
|
+ rtdetr_results.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
+ rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
+ rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
+ fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
|
+
|
|
|
+ logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"⚠️ RT-DETR detection failed: {e}, fallback to UNet-only")
|
|
|
+ fused_cells = unet_cells.copy()
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
+ return fused_cells, fusion_stats
|
|
|
+
|
|
|
+ # Phase 2: 智能融合
|
|
|
+ fused_cells, merge_stats, cell_labels = self._fuse_cells(
|
|
|
+ unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
+ )
|
|
|
+ fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
+ fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
|
+ fusion_stats['added_count'] = merge_stats['added']
|
|
|
+
|
|
|
+ # Phase 3: NMS 去重
|
|
|
+ fused_cells = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
|
+
|
|
|
+ # Phase 4: OCR 补偿(可选)
|
|
|
+ # if self.enable_ocr_compensation and ocr_boxes:
|
|
|
+ # fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
+ # fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
+ # )
|
|
|
+ # fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
+
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
+ f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
|
+ f"Added={merge_stats['added']}, Final={len(fused_cells)}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 可视化(调试)
|
|
|
+ if debug_dir:
|
|
|
+ self._visualize_fusion(
|
|
|
+ table_image, unet_cells, rtdetr_cells, fused_cells, cell_labels,
|
|
|
+ debug_dir, debug_prefix
|
|
|
+ )
|
|
|
+
|
|
|
+ return fused_cells, fusion_stats
|
|
|
+
|
|
|
+ def _fuse_cells(
|
|
|
+ self,
|
|
|
+ unet_cells: List[List[float]],
|
|
|
+ rtdetr_cells: List[List[float]],
|
|
|
+ rtdetr_scores: List[float]
|
|
|
+ ) -> Tuple[List[List[float]], Dict[str, int], List[str]]:
|
|
|
+ """
|
|
|
+ 融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
|
|
|
+
|
|
|
+ 融合规则:
|
|
|
+ 1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
|
|
|
+ - 判断RT-DETR单元格包含多少个UNet单元格
|
|
|
+ - 使用中心点+包含率判断(而非IoU)
|
|
|
+ 2. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
|
|
|
+ 3. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
+ 4. UNet 独有 → 保留
|
|
|
+
|
|
|
+ 包含关系判断逻辑:
|
|
|
+ - UNet单元格的中心点在RT-DETR内
|
|
|
+ - UNet单元格的50%以上面积在RT-DETR内
|
|
|
+ - RT-DETR包含≥2个UNet单元格
|
|
|
+ - 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ unet_cells: UNet单元格列表
|
|
|
+ rtdetr_cells: RT-DETR单元格列表
|
|
|
+ rtdetr_scores: RT-DETR置信度列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (fused_cells, stats, cell_labels)
|
|
|
+ - fused_cells: 融合后的单元格
|
|
|
+ - stats: {'merged': int, 'added': int, 'merged_cells': int}
|
|
|
+ - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
|
|
|
+ """
|
|
|
+ fused_cells = []
|
|
|
+ cell_labels = [] # 记录每个单元格的来源标签
|
|
|
+ unet_matched = [False] * len(unet_cells)
|
|
|
+ rtdetr_matched = [False] * len(rtdetr_cells)
|
|
|
+ stats = {'merged': 0, 'added': 0, 'merged_cells': 0}
|
|
|
+
|
|
|
+ # Step 1: 检测RT-DETR的合并单元格(一对多匹配)
|
|
|
+ # 遍历RT-DETR单元格,查找被包含的多个UNet单元格
|
|
|
+ for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
+ if rtdetr_matched[rt_idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 查找所有被当前RT-DETR单元格包含(或大部分包含)的UNet单元格
|
|
|
+ contained_unet = []
|
|
|
+ for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
+ if unet_matched[u_idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 判断UNet单元格是否被RT-DETR单元格包含
|
|
|
+ # 方法1: 检查UNet的中心点是否在RT-DETR内
|
|
|
+ unet_cx = (unet_cell[0] + unet_cell[2]) / 2
|
|
|
+ unet_cy = (unet_cell[1] + unet_cell[3]) / 2
|
|
|
+
|
|
|
+ if (rtdetr_cell[0] <= unet_cx <= rtdetr_cell[2] and
|
|
|
+ rtdetr_cell[1] <= unet_cy <= rtdetr_cell[3]):
|
|
|
+ # UNet中心点在RT-DETR内,计算包含程度
|
|
|
+ # 计算UNet有多少面积在RT-DETR内
|
|
|
+ intersect_x1 = max(unet_cell[0], rtdetr_cell[0])
|
|
|
+ intersect_y1 = max(unet_cell[1], rtdetr_cell[1])
|
|
|
+ intersect_x2 = min(unet_cell[2], rtdetr_cell[2])
|
|
|
+ intersect_y2 = min(unet_cell[3], rtdetr_cell[3])
|
|
|
+
|
|
|
+ if intersect_x2 > intersect_x1 and intersect_y2 > intersect_y1:
|
|
|
+ intersect_area = (intersect_x2 - intersect_x1) * (intersect_y2 - intersect_y1)
|
|
|
+ unet_area = (unet_cell[2] - unet_cell[0]) * (unet_cell[3] - unet_cell[1])
|
|
|
+ contain_ratio = intersect_area / unet_area if unet_area > 0 else 0
|
|
|
+
|
|
|
+ # 如果UNet单元格的50%以上在RT-DETR内,认为被包含
|
|
|
+ if contain_ratio > 0.5:
|
|
|
+ contained_unet.append((u_idx, contain_ratio))
|
|
|
+
|
|
|
+ # 判断是否为合并单元格(RT-DETR包含多个UNet单元格)
|
|
|
+ if len(contained_unet) >= 2:
|
|
|
+ # 合并单元格场景:优先使用RT-DETR的大框
|
|
|
+ # 条件:1) 包含2个以上UNet单元格 2) RT-DETR置信度足够高
|
|
|
+ if rtdetr_scores[rt_idx] > 0.7:
|
|
|
+ # 计算总包含率:使用所有被包含UNet单元格的外接矩形面积 vs RT-DETR面积
|
|
|
+ # 使用外接矩形更合理,因为:
|
|
|
+ # 1. 合并单元格是一个完整区域,应包括单元格间隙
|
|
|
+ # 2. 避免重复计算相邻单元格的边界
|
|
|
+ # 3. 更准确反映覆盖率(如11个连续单元格应该接近100%覆盖)
|
|
|
+ unet_indices = [u_idx for u_idx, _ in contained_unet]
|
|
|
+ bounding_x1 = min(unet_cells[i][0] for i in unet_indices)
|
|
|
+ bounding_y1 = min(unet_cells[i][1] for i in unet_indices)
|
|
|
+ bounding_x2 = max(unet_cells[i][2] for i in unet_indices)
|
|
|
+ bounding_y2 = max(unet_cells[i][3] for i in unet_indices)
|
|
|
+ total_unet_area = (bounding_x2 - bounding_x1) * (bounding_y2 - bounding_y1)
|
|
|
+
|
|
|
+ rtdetr_area = self._calc_bbox_area(rtdetr_cell)
|
|
|
+ coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
|
|
|
+
|
|
|
+ # 如果覆盖率>40%,说明这是一个真实的合并单元格
|
|
|
+ # 降低阈值从0.5到0.4,因为合并单元格可能包含很多空白区域
|
|
|
+ if coverage > 0.4:
|
|
|
+ # 认定为合并单元格,取bounding与RT-DETR的最大范围
|
|
|
+ fused_cell = [
|
|
|
+ min(bounding_x1, rtdetr_cell[0]),
|
|
|
+ min(bounding_y1, rtdetr_cell[1]),
|
|
|
+ max(bounding_x2, rtdetr_cell[2]),
|
|
|
+ max(bounding_y2, rtdetr_cell[3])
|
|
|
+ ]
|
|
|
+ fused_cells.append(fused_cell)
|
|
|
+ cell_labels.append('merged_span') # 标记为合并单元格
|
|
|
+ rtdetr_matched[rt_idx] = True
|
|
|
+ # 标记所有被包含的UNet单元格
|
|
|
+ for u_idx, contain_ratio in contained_unet:
|
|
|
+ unet_matched[u_idx] = True
|
|
|
+ stats['merged_cells'] += 1
|
|
|
+ logger.debug(
|
|
|
+ f"🔗 检测到合并单元格: RT-DETR[{rt_idx}] 包含 {len(contained_unet)} 个UNet单元格 "
|
|
|
+ f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 2: 一对一匹配(处理剩余的单元格)
|
|
|
+ for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
+ if unet_matched[u_idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ best_match_idx = -1
|
|
|
+ best_iou = 0.0
|
|
|
+
|
|
|
+ # 查找最佳匹配的 RT-DETR 单元格
|
|
|
+ for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
+ if rtdetr_matched[rt_idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
|
|
|
+ if iou > best_iou:
|
|
|
+ best_iou = iou
|
|
|
+ best_match_idx = rt_idx
|
|
|
+
|
|
|
+ # 判断是否合并
|
|
|
+ if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
|
+ # 高IoU:加权平均合并
|
|
|
+ merged_cell = self._weighted_merge_bbox(
|
|
|
+ unet_cell,
|
|
|
+ rtdetr_cells[best_match_idx],
|
|
|
+ self.unet_weight,
|
|
|
+ self.rtdetr_weight
|
|
|
+ )
|
|
|
+ fused_cells.append(merged_cell)
|
|
|
+ cell_labels.append('merged_1to1') # 标记为1:1融合
|
|
|
+ rtdetr_matched[best_match_idx] = True
|
|
|
+ unet_matched[u_idx] = True
|
|
|
+ stats['merged'] += 1
|
|
|
+ else:
|
|
|
+ # UNet 独有:保留
|
|
|
+ fused_cells.append(unet_cell)
|
|
|
+ cell_labels.append('unet_only') # 标记为UNet独有
|
|
|
+ unet_matched[u_idx] = True
|
|
|
+
|
|
|
+ # Step 3: 补充 RT-DETR 独有的高置信度单元格
|
|
|
+ for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
|
+ if not rtdetr_matched[idx] and score > 0.7:
|
|
|
+ fused_cells.append(rtdetr_cell)
|
|
|
+ cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
|
+ stats['added'] += 1
|
|
|
+
|
|
|
+ return fused_cells, stats, cell_labels
|
|
|
+
|
|
|
+ def _calc_bbox_area(self, bbox: List[float]) -> float:
|
|
|
+ """计算bbox面积"""
|
|
|
+ return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
|
|
+
|
|
|
+ def _weighted_merge_bbox(
|
|
|
+ self,
|
|
|
+ bbox1: List[float],
|
|
|
+ bbox2: List[float],
|
|
|
+ weight1: float,
|
|
|
+ weight2: float
|
|
|
+ ) -> List[float]:
|
|
|
+ """
|
|
|
+ 加权平均合并两个 bbox
|
|
|
+
|
|
|
+ Args:
|
|
|
+ bbox1: [x1, y1, x2, y2]
|
|
|
+ bbox2: [x1, y1, x2, y2]
|
|
|
+ weight1: bbox1 的权重
|
|
|
+ weight2: bbox2 的权重
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ merged_bbox: [x1, y1, x2, y2]
|
|
|
+ """
|
|
|
+ return [
|
|
|
+ weight1 * bbox1[0] + weight2 * bbox2[0],
|
|
|
+ weight1 * bbox1[1] + weight2 * bbox2[1],
|
|
|
+ weight1 * bbox1[2] + weight2 * bbox2[2],
|
|
|
+ weight1 * bbox1[3] + weight2 * bbox2[3]
|
|
|
+ ]
|
|
|
+
|
|
|
+ def _nms_filter(
|
|
|
+ self,
|
|
|
+ cells: List[List[float]],
|
|
|
+ iou_threshold: float
|
|
|
+ ) -> List[List[float]]:
|
|
|
+ """
|
|
|
+ 简单 NMS 过滤(去除高度重叠的冗余框)
|
|
|
+
|
|
|
+ 策略:按面积排序,保留大框,移除与大框高IoU的小框
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cells: 单元格列表 [[x1,y1,x2,y2], ...]
|
|
|
+ iou_threshold: IoU阈值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 过滤后的单元格列表
|
|
|
+ """
|
|
|
+ if len(cells) == 0:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 计算面积并排序(大框优先)
|
|
|
+ areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in cells]
|
|
|
+ sorted_indices = sorted(range(len(cells)), key=lambda i: areas[i], reverse=True)
|
|
|
+
|
|
|
+ keep = []
|
|
|
+ suppressed = [False] * len(cells)
|
|
|
+
|
|
|
+ for idx in sorted_indices:
|
|
|
+ if suppressed[idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ keep.append(cells[idx])
|
|
|
+
|
|
|
+ # 抑制与当前框高IoU的其他框
|
|
|
+ for other_idx in sorted_indices:
|
|
|
+ if other_idx == idx or suppressed[other_idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ iou = CoordinateUtils.calculate_iou(cells[idx], cells[other_idx])
|
|
|
+ if iou > iou_threshold:
|
|
|
+ suppressed[other_idx] = True
|
|
|
+
|
|
|
+ logger.debug(f"NMS: {len(cells)} → {len(keep)} cells (threshold={iou_threshold})")
|
|
|
+ return keep
|
|
|
+
|
|
|
+ def _compensate_with_ocr(
|
|
|
+ self,
|
|
|
+ cells: List[List[float]],
|
|
|
+ cell_labels: List[str],
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ table_size: Tuple[int, int]
|
|
|
+ ) -> Tuple[List[List[float]], List[str], int]:
|
|
|
+ """
|
|
|
+ 使用 OCR 补偿遗漏的单元格
|
|
|
+
|
|
|
+ 策略:如果 OCR 文本没有匹配到任何单元格,创建新单元格
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cells: 现有单元格列表
|
|
|
+ cell_labels: 单元格标签列表
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
+ table_size: 表格尺寸 (width, height)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (compensated_cells, compensated_labels, compensation_count)
|
|
|
+ """
|
|
|
+ compensated = cells.copy()
|
|
|
+ compensated_labels = cell_labels.copy()
|
|
|
+ compensation_count = 0
|
|
|
+ w, h = table_size
|
|
|
+
|
|
|
+ for ocr in ocr_boxes:
|
|
|
+ ocr_bbox = ocr.get('bbox', [])
|
|
|
+ if not ocr_bbox or len(ocr_bbox) < 4:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 计算 OCR 中心点
|
|
|
+ if len(ocr_bbox) == 8: # poly format
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
|
|
|
+ else: # bbox format
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
|
|
|
+
|
|
|
+ # 检查是否在任何单元格内
|
|
|
+ is_covered = False
|
|
|
+ for cell in compensated:
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
+ if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
|
|
|
+ is_covered = True
|
|
|
+ break
|
|
|
+
|
|
|
+ # 如果孤立,创建新单元格
|
|
|
+ if not is_covered:
|
|
|
+ # 扩展 OCR bbox 作为新单元格
|
|
|
+ if len(ocr_bbox) == 8:
|
|
|
+ new_cell = [
|
|
|
+ float(max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5)),
|
|
|
+ float(max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5)),
|
|
|
+ float(min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5)),
|
|
|
+ float(min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5))
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ new_cell = [
|
|
|
+ float(max(0, ocr_bbox[0] - 5)),
|
|
|
+ float(max(0, ocr_bbox[1] - 5)),
|
|
|
+ float(min(w, ocr_bbox[2] + 5)),
|
|
|
+ float(min(h, ocr_bbox[3] + 5))
|
|
|
+ ]
|
|
|
+
|
|
|
+ compensated.append(new_cell)
|
|
|
+ compensated_labels.append('new') # 标记为新增(OCR补偿)
|
|
|
+ compensation_count += 1
|
|
|
+
|
|
|
+ if compensation_count > 0:
|
|
|
+ logger.debug(f"OCR compensation: added {compensation_count} cells")
|
|
|
+
|
|
|
+ return compensated, compensated_labels, compensation_count
|
|
|
+
|
|
|
+ def _visualize_fusion(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ unet_cells: List[List[float]],
|
|
|
+ rtdetr_cells: List[List[float]],
|
|
|
+ fused_cells: List[List[float]],
|
|
|
+ cell_labels: List[str],
|
|
|
+ debug_dir: str,
|
|
|
+ debug_prefix: str
|
|
|
+ ):
|
|
|
+ """可视化融合结果(调试用)- 增强版:用颜色区分不同来源的单元格"""
|
|
|
+ try:
|
|
|
+ import cv2
|
|
|
+ from pathlib import Path
|
|
|
+
|
|
|
+ output_dir = Path(debug_dir)
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
+
|
|
|
+ # === 图1:UNet原始结果 ===
|
|
|
+ img1 = table_image.copy()
|
|
|
+ if len(img1.shape) == 2:
|
|
|
+ img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
|
|
|
+ for cell in unet_cells:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2) # 绿色
|
|
|
+ cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
|
|
+
|
|
|
+ # === 图2:RT-DETR原始结果 ===
|
|
|
+ img2 = table_image.copy()
|
|
|
+ if len(img2.shape) == 2:
|
|
|
+ img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
|
|
|
+ for cell in rtdetr_cells:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2) # 蓝色
|
|
|
+ cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
|
|
|
+
|
|
|
+ # === 图3:融合结果(按来源分类)===
|
|
|
+ img3 = table_image.copy()
|
|
|
+ if len(img3.shape) == 2:
|
|
|
+ img3 = cv2.cvtColor(img3, cv2.COLOR_GRAY2BGR)
|
|
|
+
|
|
|
+ # 根据标签分类单元格(使用 _fuse_cells 中记录的标签)
|
|
|
+ unet_only = [] # UNet独有(绿色)
|
|
|
+ rtdetr_only = [] # RT-DETR独有(蓝色)
|
|
|
+ merged_cells_1to1 = [] # 1:1融合单元格(黄色)
|
|
|
+ merged_cells_span = [] # 合并单元格(品红色,RT-DETR检测的跨格单元格)
|
|
|
+ new_cells = [] # 新增单元格(紫色)
|
|
|
+
|
|
|
+ for fused_cell, label in zip(fused_cells, cell_labels):
|
|
|
+ if label == 'unet_only':
|
|
|
+ unet_only.append(fused_cell)
|
|
|
+ elif label == 'rtdetr_only':
|
|
|
+ rtdetr_only.append(fused_cell)
|
|
|
+ elif label == 'merged_1to1':
|
|
|
+ merged_cells_1to1.append(fused_cell)
|
|
|
+ elif label == 'merged_span':
|
|
|
+ merged_cells_span.append(fused_cell)
|
|
|
+ elif label == 'new':
|
|
|
+ new_cells.append(fused_cell)
|
|
|
+
|
|
|
+ # 绘制不同类型的单元格
|
|
|
+ for cell in unet_only:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 0), 2) # 绿色 - UNet独有
|
|
|
+
|
|
|
+ for cell in rtdetr_only:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 0), 2) # 蓝色 - RT-DETR独有
|
|
|
+
|
|
|
+ for cell in merged_cells_1to1:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 3) # 黄色 - 1:1融合(加粗)
|
|
|
+
|
|
|
+ for cell in merged_cells_span:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 255), 4) # 品红色 - 合并单元格(加粗)
|
|
|
+
|
|
|
+ for cell in new_cells:
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2) # 紫色 - 新增
|
|
|
+
|
|
|
+ # 添加图例
|
|
|
+ legend_y = 30
|
|
|
+ cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
|
|
|
+ legend_y += 35
|
|
|
+ cv2.putText(img3, f"Green: UNet-only ({len(unet_only)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
|
|
+ legend_y += 30
|
|
|
+ cv2.putText(img3, f"Blue: RTDETR-only ({len(rtdetr_only)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
|
|
|
+ legend_y += 30
|
|
|
+ cv2.putText(img3, f"Yellow: 1:1 Merged ({len(merged_cells_1to1)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
|
|
+ legend_y += 30
|
|
|
+ cv2.putText(img3, f"Magenta: Span Cells ({len(merged_cells_span)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 255), 2)
|
|
|
+ if new_cells:
|
|
|
+ legend_y += 30
|
|
|
+ cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
|
|
|
+
|
|
|
+ # 拼接三栏对比
|
|
|
+ vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
|
|
|
+ vis_canvas[:, :w] = img1
|
|
|
+ vis_canvas[:, w:2*w] = img2
|
|
|
+ vis_canvas[:, 2*w:] = img3
|
|
|
+
|
|
|
+ # 保存
|
|
|
+ output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
|
|
|
+ cv2.imwrite(str(output_path), vis_canvas)
|
|
|
+ logger.info(f"💾 融合可视化已保存: {output_path}")
|
|
|
+ logger.info(f" 📊 单元格分类: UNet独有={len(unet_only)}, RT-DETR独有={len(rtdetr_only)}, "
|
|
|
+ f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, 新增={len(new_cells)}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"Failed to visualize fusion: {e}")
|