|
@@ -0,0 +1,471 @@
|
|
|
|
|
+"""多源单元格融合引擎:融合 UNet、RT-DETR 和 OCR 结果"""
|
|
|
|
|
+
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+from typing import Dict, List, Tuple, Optional, Any
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
+
|
|
|
|
|
+try:
|
|
|
|
|
+ from ocr_utils.coordinate_utils import CoordinateUtils
|
|
|
|
|
+except ImportError:
|
|
|
|
|
+ from ...core.coordinate_utils import CoordinateUtils
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class CellFusionEngine:
|
|
|
|
|
+ """
|
|
|
|
|
+ 多源单元格融合引擎
|
|
|
|
|
+
|
|
|
|
|
+ 融合策略:
|
|
|
|
|
+ 1. UNet 连通域检测(结构性强,适合清晰表格)
|
|
|
|
|
+ 2. RT-DETR 端到端检测(鲁棒性强,适合噪声表格)
|
|
|
|
|
+ 3. OCR 文本位置(验证单元格存在性)
|
|
|
|
|
+
|
|
|
|
|
+ 自适应策略:
|
|
|
|
|
+ - 文字PDF (pdf_type='txt'): 跳过 RT-DETR,纯 UNet 模式(无噪声)
|
|
|
|
|
+ - 扫描PDF (pdf_type='ocr'): 启用融合模式(有噪声)
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ rtdetr_detector: Optional[Any] = None,
|
|
|
|
|
+ config: Optional[Dict[str, Any]] = None
|
|
|
|
|
+ ):
|
|
|
|
|
+ """
|
|
|
|
|
+ 初始化融合引擎
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ rtdetr_detector: RT-DETR 检测器实例(可选)
|
|
|
|
|
+ config: 融合配置
|
|
|
|
|
+ - unet_weight: 0.6 (UNet 权重)
|
|
|
|
|
+ - rtdetr_weight: 0.4 (RT-DETR 权重)
|
|
|
|
|
+ - iou_merge_threshold: 0.7 (高IoU合并阈值)
|
|
|
|
|
+ - iou_nms_threshold: 0.5 (NMS去重阈值)
|
|
|
|
|
+ - rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
|
|
|
|
|
+ - enable_ocr_compensation: True (启用OCR补偿)
|
|
|
|
|
+ - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
|
|
|
|
|
+ """
|
|
|
|
|
+ self.rtdetr_detector = rtdetr_detector
|
|
|
|
|
+ self.config = config or {}
|
|
|
|
|
+
|
|
|
|
|
+ # 融合参数
|
|
|
|
|
+ self.unet_weight = self.config.get('unet_weight', 0.6)
|
|
|
|
|
+ self.rtdetr_weight = self.config.get('rtdetr_weight', 0.4)
|
|
|
|
|
+ self.iou_merge_threshold = self.config.get('iou_merge_threshold', 0.7)
|
|
|
|
|
+ self.iou_nms_threshold = self.config.get('iou_nms_threshold', 0.5)
|
|
|
|
|
+ self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
|
|
|
+ self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
|
|
|
+ self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"🔧 CellFusionEngine initialized: "
|
|
|
|
|
+ f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
|
|
+ f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}")
|
|
|
|
|
+
|
|
|
|
|
+ def should_use_rtdetr(
|
|
|
|
|
+ self,
|
|
|
|
|
+ pdf_type: str,
|
|
|
|
|
+ unet_cell_count: int,
|
|
|
|
|
+ table_size: Tuple[int, int]
|
|
|
|
|
+ ) -> bool:
|
|
|
|
|
+ """
|
|
|
|
|
+ 判断是否需要使用 RT-DETR 检测(自适应策略)
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
|
|
+ unet_cell_count: UNet检测到的单元格数量
|
|
|
|
|
+ table_size: 表格尺寸 (width, height)
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 是否使用 RT-DETR
|
|
|
|
|
+ """
|
|
|
|
|
+ # 策略1: 文字PDF跳过RT-DETR(无噪声,UNet结果已足够准确)
|
|
|
|
|
+ if pdf_type == 'txt' and self.skip_rtdetr_for_txt_pdf:
|
|
|
|
|
+ logger.debug(f"📄 Text PDF detected, skip RT-DETR (UNet cells: {unet_cell_count})")
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ # 策略2: 如果 RT-DETR 检测器未初始化,跳过
|
|
|
|
|
+ if self.rtdetr_detector is None:
|
|
|
|
|
+ logger.debug("⚠️ RT-DETR detector not initialized, skip fusion")
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ # 策略3: UNet检测结果为空,必须使用RT-DETR补救
|
|
|
|
|
+ if unet_cell_count == 0:
|
|
|
|
|
+ logger.info("🚨 UNet detected 0 cells, force enable RT-DETR")
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ # 策略4: 扫描PDF,启用融合模式
|
|
|
|
|
+ logger.debug(f"🔍 Scan PDF detected, enable RT-DETR fusion (UNet cells: {unet_cell_count})")
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ def fuse(
|
|
|
|
|
+ self,
|
|
|
|
|
+ table_image: np.ndarray,
|
|
|
|
|
+ unet_cells: List[List[float]],
|
|
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
|
|
+ pdf_type: str = 'ocr',
|
|
|
|
|
+ upscale: float = 1.0,
|
|
|
|
|
+ debug_dir: Optional[str] = None,
|
|
|
|
|
+ debug_prefix: str = "fusion"
|
|
|
|
|
+ ) -> Tuple[List[List[float]], Dict[str, Any]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 融合多源单元格检测结果
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ table_image: 表格图像(原图坐标系)
|
|
|
|
|
+ unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
|
|
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
|
|
+ pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
|
|
+ upscale: UNet的上采样比例
|
|
|
|
|
+ debug_dir: 调试输出目录(可选)
|
|
|
|
|
+ debug_prefix: 调试文件前缀
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ (fused_cells, fusion_stats)
|
|
|
|
|
+ - fused_cells: 融合后的单元格列表 [[x1,y1,x2,y2], ...]
|
|
|
|
|
+ - fusion_stats: 融合统计信息
|
|
|
|
|
+ """
|
|
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
|
|
+
|
|
|
|
|
+ # 决策:是否使用 RT-DETR
|
|
|
|
|
+ use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
|
|
|
|
|
+
|
|
|
|
|
+ fusion_stats = {
|
|
|
|
|
+ 'use_rtdetr': use_rtdetr,
|
|
|
|
|
+ 'unet_count': len(unet_cells),
|
|
|
|
|
+ 'rtdetr_count': 0,
|
|
|
|
|
+ 'fused_count': 0,
|
|
|
|
|
+ 'merged_count': 0,
|
|
|
|
|
+ 'added_count': 0,
|
|
|
|
|
+ 'ocr_compensated_count': 0
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 如果不使用RT-DETR,直接返回UNet结果
|
|
|
|
|
+ if not use_rtdetr:
|
|
|
|
|
+ fused_cells = unet_cells.copy()
|
|
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
+
|
|
|
|
|
+ # 可选:OCR补偿
|
|
|
|
|
+ if self.enable_ocr_compensation and ocr_boxes:
|
|
|
|
|
+ fused_cells, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
+ fused_cells, ocr_boxes, (w, h)
|
|
|
|
|
+ )
|
|
|
|
|
+ fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
|
|
|
|
|
+ return fused_cells, fusion_stats
|
|
|
|
|
+
|
|
|
|
|
+ # Phase 1: RT-DETR 检测
|
|
|
|
|
+ try:
|
|
|
|
|
+ rtdetr_results = self.rtdetr_detector.detect(
|
|
|
|
|
+ table_image,
|
|
|
|
|
+ conf_threshold=self.rtdetr_conf_threshold
|
|
|
|
|
+ )
|
|
|
|
|
+ rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
|
|
+ rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
|
|
+ fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
|
|
|
+
|
|
|
|
|
+ logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"⚠️ RT-DETR detection failed: {e}, fallback to UNet-only")
|
|
|
|
|
+ fused_cells = unet_cells.copy()
|
|
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
+ return fused_cells, fusion_stats
|
|
|
|
|
+
|
|
|
|
|
+ # Phase 2: 智能融合
|
|
|
|
|
+ fused_cells, merge_stats = self._fuse_cells(
|
|
|
|
|
+ unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
|
|
+ )
|
|
|
|
|
+ fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
|
|
+ fusion_stats['added_count'] = merge_stats['added']
|
|
|
|
|
+
|
|
|
|
|
+ # Phase 3: NMS 去重
|
|
|
|
|
+ fused_cells = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
|
|
|
+
|
|
|
|
|
+ # Phase 4: OCR 补偿(可选)
|
|
|
|
|
+ if self.enable_ocr_compensation and ocr_boxes:
|
|
|
|
|
+ fused_cells, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
+ fused_cells, ocr_boxes, (w, h)
|
|
|
|
|
+ )
|
|
|
|
|
+ fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
|
|
+
|
|
|
|
|
+ fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
|
|
+ f"Merged={merge_stats['merged']}, Added={merge_stats['added']}, "
|
|
|
|
|
+ f"Final={len(fused_cells)}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 可视化(调试)
|
|
|
|
|
+ if debug_dir:
|
|
|
|
|
+ self._visualize_fusion(
|
|
|
|
|
+ table_image, unet_cells, rtdetr_cells, fused_cells,
|
|
|
|
|
+ debug_dir, debug_prefix
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return fused_cells, fusion_stats
|
|
|
|
|
+
|
|
|
|
|
+ def _fuse_cells(
|
|
|
|
|
+ self,
|
|
|
|
|
+ unet_cells: List[List[float]],
|
|
|
|
|
+ rtdetr_cells: List[List[float]],
|
|
|
|
|
+ rtdetr_scores: List[float]
|
|
|
|
|
+ ) -> Tuple[List[List[float]], Dict[str, int]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 融合 UNet 和 RT-DETR 检测结果
|
|
|
|
|
+
|
|
|
|
|
+ 融合规则:
|
|
|
|
|
+ 1. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并
|
|
|
|
|
+ 2. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
|
|
+ 3. UNet 独有 → 保留
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ unet_cells: UNet单元格列表
|
|
|
|
|
+ rtdetr_cells: RT-DETR单元格列表
|
|
|
|
|
+ rtdetr_scores: RT-DETR置信度列表
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ (fused_cells, stats)
|
|
|
|
|
+ - fused_cells: 融合后的单元格
|
|
|
|
|
+ - stats: {'merged': int, 'added': int}
|
|
|
|
|
+ """
|
|
|
|
|
+ fused_cells = []
|
|
|
|
|
+ rtdetr_matched = [False] * len(rtdetr_cells)
|
|
|
|
|
+ stats = {'merged': 0, 'added': 0}
|
|
|
|
|
+
|
|
|
|
|
+ # Step 1: 遍历 UNet 单元格,尝试与 RT-DETR 匹配
|
|
|
|
|
+ for unet_cell in unet_cells:
|
|
|
|
|
+ best_match_idx = -1
|
|
|
|
|
+ best_iou = 0.0
|
|
|
|
|
+
|
|
|
|
|
+ # 查找最佳匹配的 RT-DETR 单元格
|
|
|
|
|
+ for idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
|
|
+ if rtdetr_matched[idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
|
|
|
|
|
+ if iou > best_iou:
|
|
|
|
|
+ best_iou = iou
|
|
|
|
|
+ best_match_idx = idx
|
|
|
|
|
+
|
|
|
|
|
+ # 判断是否合并
|
|
|
|
|
+ if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
|
|
|
+ # 高IoU:加权平均合并
|
|
|
|
|
+ merged_cell = self._weighted_merge_bbox(
|
|
|
|
|
+ unet_cell,
|
|
|
|
|
+ rtdetr_cells[best_match_idx],
|
|
|
|
|
+ self.unet_weight,
|
|
|
|
|
+ self.rtdetr_weight
|
|
|
|
|
+ )
|
|
|
|
|
+ fused_cells.append(merged_cell)
|
|
|
|
|
+ rtdetr_matched[best_match_idx] = True
|
|
|
|
|
+ stats['merged'] += 1
|
|
|
|
|
+ else:
|
|
|
|
|
+ # UNet 独有:保留
|
|
|
|
|
+ fused_cells.append(unet_cell)
|
|
|
|
|
+
|
|
|
|
|
+ # Step 2: 补充 RT-DETR 独有的高置信度单元格
|
|
|
|
|
+ for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
|
|
|
+ if not rtdetr_matched[idx] and score > 0.7:
|
|
|
|
|
+ fused_cells.append(rtdetr_cell)
|
|
|
|
|
+ stats['added'] += 1
|
|
|
|
|
+
|
|
|
|
|
+ return fused_cells, stats
|
|
|
|
|
+
|
|
|
|
|
+ def _weighted_merge_bbox(
|
|
|
|
|
+ self,
|
|
|
|
|
+ bbox1: List[float],
|
|
|
|
|
+ bbox2: List[float],
|
|
|
|
|
+ weight1: float,
|
|
|
|
|
+ weight2: float
|
|
|
|
|
+ ) -> List[float]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 加权平均合并两个 bbox
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ bbox1: [x1, y1, x2, y2]
|
|
|
|
|
+ bbox2: [x1, y1, x2, y2]
|
|
|
|
|
+ weight1: bbox1 的权重
|
|
|
|
|
+ weight2: bbox2 的权重
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ merged_bbox: [x1, y1, x2, y2]
|
|
|
|
|
+ """
|
|
|
|
|
+ return [
|
|
|
|
|
+ weight1 * bbox1[0] + weight2 * bbox2[0],
|
|
|
|
|
+ weight1 * bbox1[1] + weight2 * bbox2[1],
|
|
|
|
|
+ weight1 * bbox1[2] + weight2 * bbox2[2],
|
|
|
|
|
+ weight1 * bbox1[3] + weight2 * bbox2[3]
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ def _nms_filter(
|
|
|
|
|
+ self,
|
|
|
|
|
+ cells: List[List[float]],
|
|
|
|
|
+ iou_threshold: float
|
|
|
|
|
+ ) -> List[List[float]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 简单 NMS 过滤(去除高度重叠的冗余框)
|
|
|
|
|
+
|
|
|
|
|
+ 策略:按面积排序,保留大框,移除与大框高IoU的小框
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ cells: 单元格列表 [[x1,y1,x2,y2], ...]
|
|
|
|
|
+ iou_threshold: IoU阈值
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 过滤后的单元格列表
|
|
|
|
|
+ """
|
|
|
|
|
+ if len(cells) == 0:
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ # 计算面积并排序(大框优先)
|
|
|
|
|
+ areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in cells]
|
|
|
|
|
+ sorted_indices = sorted(range(len(cells)), key=lambda i: areas[i], reverse=True)
|
|
|
|
|
+
|
|
|
|
|
+ keep = []
|
|
|
|
|
+ suppressed = [False] * len(cells)
|
|
|
|
|
+
|
|
|
|
|
+ for idx in sorted_indices:
|
|
|
|
|
+ if suppressed[idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ keep.append(cells[idx])
|
|
|
|
|
+
|
|
|
|
|
+ # 抑制与当前框高IoU的其他框
|
|
|
|
|
+ for other_idx in sorted_indices:
|
|
|
|
|
+ if other_idx == idx or suppressed[other_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ iou = CoordinateUtils.calculate_iou(cells[idx], cells[other_idx])
|
|
|
|
|
+ if iou > iou_threshold:
|
|
|
|
|
+ suppressed[other_idx] = True
|
|
|
|
|
+
|
|
|
|
|
+ logger.debug(f"NMS: {len(cells)} → {len(keep)} cells (threshold={iou_threshold})")
|
|
|
|
|
+ return keep
|
|
|
|
|
+
|
|
|
|
|
+ def _compensate_with_ocr(
|
|
|
|
|
+ self,
|
|
|
|
|
+ cells: List[List[float]],
|
|
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
|
|
+ table_size: Tuple[int, int]
|
|
|
|
|
+ ) -> Tuple[List[List[float]], int]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 使用 OCR 补偿遗漏的单元格
|
|
|
|
|
+
|
|
|
|
|
+ 策略:如果 OCR 文本没有匹配到任何单元格,创建新单元格
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ cells: 现有单元格列表
|
|
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
|
|
+ table_size: 表格尺寸 (width, height)
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ (compensated_cells, compensation_count)
|
|
|
|
|
+ """
|
|
|
|
|
+ compensated = cells.copy()
|
|
|
|
|
+ compensation_count = 0
|
|
|
|
|
+ w, h = table_size
|
|
|
|
|
+
|
|
|
|
|
+ for ocr in ocr_boxes:
|
|
|
|
|
+ ocr_bbox = ocr.get('bbox', [])
|
|
|
|
|
+ if not ocr_bbox or len(ocr_bbox) < 4:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 计算 OCR 中心点
|
|
|
|
|
+ if len(ocr_bbox) == 8: # poly format
|
|
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
|
|
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
|
|
|
|
|
+ else: # bbox format
|
|
|
|
|
+ ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
|
|
|
|
|
+ ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
|
|
|
|
|
+
|
|
|
|
|
+ # 检查是否在任何单元格内
|
|
|
|
|
+ is_covered = False
|
|
|
|
|
+ for cell in compensated:
|
|
|
|
|
+ x1, y1, x2, y2 = cell
|
|
|
|
|
+ if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
|
|
|
|
|
+ is_covered = True
|
|
|
|
|
+ break
|
|
|
|
|
+
|
|
|
|
|
+ # 如果孤立,创建新单元格
|
|
|
|
|
+ if not is_covered:
|
|
|
|
|
+ # 扩展 OCR bbox 作为新单元格
|
|
|
|
|
+ if len(ocr_bbox) == 8:
|
|
|
|
|
+ new_cell = [
|
|
|
|
|
+ max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5),
|
|
|
|
|
+ max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5),
|
|
|
|
|
+ min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5),
|
|
|
|
|
+ min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5)
|
|
|
|
|
+ ]
|
|
|
|
|
+ else:
|
|
|
|
|
+ new_cell = [
|
|
|
|
|
+ max(0, ocr_bbox[0] - 5),
|
|
|
|
|
+ max(0, ocr_bbox[1] - 5),
|
|
|
|
|
+ min(w, ocr_bbox[2] + 5),
|
|
|
|
|
+ min(h, ocr_bbox[3] + 5)
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ compensated.append(new_cell)
|
|
|
|
|
+ compensation_count += 1
|
|
|
|
|
+
|
|
|
|
|
+ if compensation_count > 0:
|
|
|
|
|
+ logger.debug(f"OCR compensation: added {compensation_count} cells")
|
|
|
|
|
+
|
|
|
|
|
+ return compensated, compensation_count
|
|
|
|
|
+
|
|
|
|
|
+ def _visualize_fusion(
|
|
|
|
|
+ self,
|
|
|
|
|
+ table_image: np.ndarray,
|
|
|
|
|
+ unet_cells: List[List[float]],
|
|
|
|
|
+ rtdetr_cells: List[List[float]],
|
|
|
|
|
+ fused_cells: List[List[float]],
|
|
|
|
|
+ debug_dir: str,
|
|
|
|
|
+ debug_prefix: str
|
|
|
|
|
+ ):
|
|
|
|
|
+ """可视化融合结果(调试用)"""
|
|
|
|
|
+ try:
|
|
|
|
|
+ import cv2
|
|
|
|
|
+ from pathlib import Path
|
|
|
|
|
+
|
|
|
|
|
+ output_dir = Path(debug_dir)
|
|
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+
|
|
|
|
|
+ # 创建三栏对比图
|
|
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
|
|
+ vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
|
|
|
|
|
+
|
|
|
|
|
+ # 左栏:UNet
|
|
|
|
|
+ img1 = table_image.copy()
|
|
|
|
|
+ for cell in unet_cells:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
|
|
+ cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 中栏:RT-DETR
|
|
|
|
|
+ img2 = table_image.copy()
|
|
|
|
|
+ for cell in rtdetr_cells:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
|
|
|
|
+ cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 右栏:融合结果
|
|
|
|
|
+ img3 = table_image.copy()
|
|
|
|
|
+ for cell in fused_cells:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 2)
|
|
|
|
|
+ cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, 30),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 拼接
|
|
|
|
|
+ vis_canvas[:, :w] = img1
|
|
|
|
|
+ vis_canvas[:, w:2*w] = img2
|
|
|
|
|
+ vis_canvas[:, 2*w:] = img3
|
|
|
|
|
+
|
|
|
|
|
+ # 保存
|
|
|
|
|
+ output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
|
|
|
|
|
+ cv2.imwrite(str(output_path), vis_canvas)
|
|
|
|
|
+ logger.debug(f"💾 Fusion visualization saved: {output_path}")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"Failed to visualize fusion: {e}")
|