Pārlūkot izejas kodu

feat(cell_fusion): 添加多源单元格融合引擎,支持 UNet、RT-DETR 和 OCR 结果融合

zhch158_admin 3 nedēļas atpakaļ
vecāks
revīzija
d62cc9a6ee

+ 471 - 0
ocr_tools/universal_doc_parser/models/adapters/wired_table/cell_fusion.py

@@ -0,0 +1,471 @@
+"""多源单元格融合引擎:融合 UNet、RT-DETR 和 OCR 结果"""
+
+import numpy as np
+from typing import Dict, List, Tuple, Optional, Any
+from loguru import logger
+
+try:
+    from ocr_utils.coordinate_utils import CoordinateUtils
+except ImportError:
+    from ...core.coordinate_utils import CoordinateUtils
+
+
+class CellFusionEngine:
+    """
+    多源单元格融合引擎
+    
+    融合策略:
+    1. UNet 连通域检测(结构性强,适合清晰表格)
+    2. RT-DETR 端到端检测(鲁棒性强,适合噪声表格)
+    3. OCR 文本位置(验证单元格存在性)
+    
+    自适应策略:
+    - 文字PDF (pdf_type='txt'): 跳过 RT-DETR,纯 UNet 模式(无噪声)
+    - 扫描PDF (pdf_type='ocr'): 启用融合模式(有噪声)
+    """
+    
+    def __init__(
+        self,
+        rtdetr_detector: Optional[Any] = None,
+        config: Optional[Dict[str, Any]] = None
+    ):
+        """
+        初始化融合引擎
+        
+        Args:
+            rtdetr_detector: RT-DETR 检测器实例(可选)
+            config: 融合配置
+                - unet_weight: 0.6 (UNet 权重)
+                - rtdetr_weight: 0.4 (RT-DETR 权重)
+                - iou_merge_threshold: 0.7 (高IoU合并阈值)
+                - iou_nms_threshold: 0.5 (NMS去重阈值)
+                - rtdetr_conf_threshold: 0.5 (RT-DETR置信度阈值)
+                - enable_ocr_compensation: True (启用OCR补偿)
+                - skip_rtdetr_for_txt_pdf: True (文字PDF跳过RT-DETR)
+        """
+        self.rtdetr_detector = rtdetr_detector
+        self.config = config or {}
+        
+        # 融合参数
+        self.unet_weight = self.config.get('unet_weight', 0.6)
+        self.rtdetr_weight = self.config.get('rtdetr_weight', 0.4)
+        self.iou_merge_threshold = self.config.get('iou_merge_threshold', 0.7)
+        self.iou_nms_threshold = self.config.get('iou_nms_threshold', 0.5)
+        self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
+        self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
+        self.skip_rtdetr_for_txt_pdf = self.config.get('skip_rtdetr_for_txt_pdf', True)
+        
+        logger.info(f"🔧 CellFusionEngine initialized: "
+                   f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
+                   f"iou_merge={self.iou_merge_threshold}, skip_txt_pdf={self.skip_rtdetr_for_txt_pdf}")
+    
+    def should_use_rtdetr(
+        self,
+        pdf_type: str,
+        unet_cell_count: int,
+        table_size: Tuple[int, int]
+    ) -> bool:
+        """
+        判断是否需要使用 RT-DETR 检测(自适应策略)
+        
+        Args:
+            pdf_type: PDF类型 ('txt' 或 'ocr')
+            unet_cell_count: UNet检测到的单元格数量
+            table_size: 表格尺寸 (width, height)
+            
+        Returns:
+            是否使用 RT-DETR
+        """
+        # 策略1: 文字PDF跳过RT-DETR(无噪声,UNet结果已足够准确)
+        if pdf_type == 'txt' and self.skip_rtdetr_for_txt_pdf:
+            logger.debug(f"📄 Text PDF detected, skip RT-DETR (UNet cells: {unet_cell_count})")
+            return False
+        
+        # 策略2: 如果 RT-DETR 检测器未初始化,跳过
+        if self.rtdetr_detector is None:
+            logger.debug("⚠️ RT-DETR detector not initialized, skip fusion")
+            return False
+        
+        # 策略3: UNet检测结果为空,必须使用RT-DETR补救
+        if unet_cell_count == 0:
+            logger.info("🚨 UNet detected 0 cells, force enable RT-DETR")
+            return True
+        
+        # 策略4: 扫描PDF,启用融合模式
+        logger.debug(f"🔍 Scan PDF detected, enable RT-DETR fusion (UNet cells: {unet_cell_count})")
+        return True
+    
+    def fuse(
+        self,
+        table_image: np.ndarray,
+        unet_cells: List[List[float]],
+        ocr_boxes: List[Dict[str, Any]],
+        pdf_type: str = 'ocr',
+        upscale: float = 1.0,
+        debug_dir: Optional[str] = None,
+        debug_prefix: str = "fusion"
+    ) -> Tuple[List[List[float]], Dict[str, Any]]:
+        """
+        融合多源单元格检测结果
+        
+        Args:
+            table_image: 表格图像(原图坐标系)
+            unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
+            ocr_boxes: OCR结果列表
+            pdf_type: PDF类型 ('txt' 或 'ocr')
+            upscale: UNet的上采样比例
+            debug_dir: 调试输出目录(可选)
+            debug_prefix: 调试文件前缀
+            
+        Returns:
+            (fused_cells, fusion_stats)
+            - fused_cells: 融合后的单元格列表 [[x1,y1,x2,y2], ...]
+            - fusion_stats: 融合统计信息
+        """
+        h, w = table_image.shape[:2]
+        
+        # 决策:是否使用 RT-DETR
+        use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
+        
+        fusion_stats = {
+            'use_rtdetr': use_rtdetr,
+            'unet_count': len(unet_cells),
+            'rtdetr_count': 0,
+            'fused_count': 0,
+            'merged_count': 0,
+            'added_count': 0,
+            'ocr_compensated_count': 0
+        }
+        
+        # 如果不使用RT-DETR,直接返回UNet结果
+        if not use_rtdetr:
+            fused_cells = unet_cells.copy()
+            fusion_stats['fused_count'] = len(fused_cells)
+            
+            # 可选:OCR补偿
+            if self.enable_ocr_compensation and ocr_boxes:
+                fused_cells, ocr_comp_count = self._compensate_with_ocr(
+                    fused_cells, ocr_boxes, (w, h)
+                )
+                fusion_stats['ocr_compensated_count'] = ocr_comp_count
+                fusion_stats['fused_count'] = len(fused_cells)
+            
+            logger.info(f"📊 Fusion (UNet-only): {len(unet_cells)} → {len(fused_cells)} cells")
+            return fused_cells, fusion_stats
+        
+        # Phase 1: RT-DETR 检测
+        try:
+            rtdetr_results = self.rtdetr_detector.detect(
+                table_image,
+                conf_threshold=self.rtdetr_conf_threshold
+            )
+            rtdetr_cells = [res['bbox'] for res in rtdetr_results]
+            rtdetr_scores = [res['score'] for res in rtdetr_results]
+            fusion_stats['rtdetr_count'] = len(rtdetr_cells)
+            
+            logger.debug(f"RT-DETR detected {len(rtdetr_cells)} cells")
+        except Exception as e:
+            logger.warning(f"⚠️ RT-DETR detection failed: {e}, fallback to UNet-only")
+            fused_cells = unet_cells.copy()
+            fusion_stats['fused_count'] = len(fused_cells)
+            return fused_cells, fusion_stats
+        
+        # Phase 2: 智能融合
+        fused_cells, merge_stats = self._fuse_cells(
+            unet_cells, rtdetr_cells, rtdetr_scores
+        )
+        fusion_stats['merged_count'] = merge_stats['merged']
+        fusion_stats['added_count'] = merge_stats['added']
+        
+        # Phase 3: NMS 去重
+        fused_cells = self._nms_filter(fused_cells, self.iou_nms_threshold)
+        
+        # Phase 4: OCR 补偿(可选)
+        if self.enable_ocr_compensation and ocr_boxes:
+            fused_cells, ocr_comp_count = self._compensate_with_ocr(
+                fused_cells, ocr_boxes, (w, h)
+            )
+            fusion_stats['ocr_compensated_count'] = ocr_comp_count
+        
+        fusion_stats['fused_count'] = len(fused_cells)
+        
+        logger.info(
+            f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
+            f"Merged={merge_stats['merged']}, Added={merge_stats['added']}, "
+            f"Final={len(fused_cells)}"
+        )
+        
+        # 可视化(调试)
+        if debug_dir:
+            self._visualize_fusion(
+                table_image, unet_cells, rtdetr_cells, fused_cells,
+                debug_dir, debug_prefix
+            )
+        
+        return fused_cells, fusion_stats
+    
+    def _fuse_cells(
+        self,
+        unet_cells: List[List[float]],
+        rtdetr_cells: List[List[float]],
+        rtdetr_scores: List[float]
+    ) -> Tuple[List[List[float]], Dict[str, int]]:
+        """
+        融合 UNet 和 RT-DETR 检测结果
+        
+        融合规则:
+        1. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并
+        2. RT-DETR 独有 + 高置信度 (>0.7) → 补充
+        3. UNet 独有 → 保留
+        
+        Args:
+            unet_cells: UNet单元格列表
+            rtdetr_cells: RT-DETR单元格列表
+            rtdetr_scores: RT-DETR置信度列表
+            
+        Returns:
+            (fused_cells, stats)
+            - fused_cells: 融合后的单元格
+            - stats: {'merged': int, 'added': int}
+        """
+        fused_cells = []
+        rtdetr_matched = [False] * len(rtdetr_cells)
+        stats = {'merged': 0, 'added': 0}
+        
+        # Step 1: 遍历 UNet 单元格,尝试与 RT-DETR 匹配
+        for unet_cell in unet_cells:
+            best_match_idx = -1
+            best_iou = 0.0
+            
+            # 查找最佳匹配的 RT-DETR 单元格
+            for idx, rtdetr_cell in enumerate(rtdetr_cells):
+                if rtdetr_matched[idx]:
+                    continue
+                
+                iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
+                if iou > best_iou:
+                    best_iou = iou
+                    best_match_idx = idx
+            
+            # 判断是否合并
+            if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
+                # 高IoU:加权平均合并
+                merged_cell = self._weighted_merge_bbox(
+                    unet_cell,
+                    rtdetr_cells[best_match_idx],
+                    self.unet_weight,
+                    self.rtdetr_weight
+                )
+                fused_cells.append(merged_cell)
+                rtdetr_matched[best_match_idx] = True
+                stats['merged'] += 1
+            else:
+                # UNet 独有:保留
+                fused_cells.append(unet_cell)
+        
+        # Step 2: 补充 RT-DETR 独有的高置信度单元格
+        for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
+            if not rtdetr_matched[idx] and score > 0.7:
+                fused_cells.append(rtdetr_cell)
+                stats['added'] += 1
+        
+        return fused_cells, stats
+    
+    def _weighted_merge_bbox(
+        self,
+        bbox1: List[float],
+        bbox2: List[float],
+        weight1: float,
+        weight2: float
+    ) -> List[float]:
+        """
+        加权平均合并两个 bbox
+        
+        Args:
+            bbox1: [x1, y1, x2, y2]
+            bbox2: [x1, y1, x2, y2]
+            weight1: bbox1 的权重
+            weight2: bbox2 的权重
+            
+        Returns:
+            merged_bbox: [x1, y1, x2, y2]
+        """
+        return [
+            weight1 * bbox1[0] + weight2 * bbox2[0],
+            weight1 * bbox1[1] + weight2 * bbox2[1],
+            weight1 * bbox1[2] + weight2 * bbox2[2],
+            weight1 * bbox1[3] + weight2 * bbox2[3]
+        ]
+    
+    def _nms_filter(
+        self,
+        cells: List[List[float]],
+        iou_threshold: float
+    ) -> List[List[float]]:
+        """
+        简单 NMS 过滤(去除高度重叠的冗余框)
+        
+        策略:按面积排序,保留大框,移除与大框高IoU的小框
+        
+        Args:
+            cells: 单元格列表 [[x1,y1,x2,y2], ...]
+            iou_threshold: IoU阈值
+            
+        Returns:
+            过滤后的单元格列表
+        """
+        if len(cells) == 0:
+            return []
+        
+        # 计算面积并排序(大框优先)
+        areas = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2 in cells]
+        sorted_indices = sorted(range(len(cells)), key=lambda i: areas[i], reverse=True)
+        
+        keep = []
+        suppressed = [False] * len(cells)
+        
+        for idx in sorted_indices:
+            if suppressed[idx]:
+                continue
+            
+            keep.append(cells[idx])
+            
+            # 抑制与当前框高IoU的其他框
+            for other_idx in sorted_indices:
+                if other_idx == idx or suppressed[other_idx]:
+                    continue
+                
+                iou = CoordinateUtils.calculate_iou(cells[idx], cells[other_idx])
+                if iou > iou_threshold:
+                    suppressed[other_idx] = True
+        
+        logger.debug(f"NMS: {len(cells)} → {len(keep)} cells (threshold={iou_threshold})")
+        return keep
+    
+    def _compensate_with_ocr(
+        self,
+        cells: List[List[float]],
+        ocr_boxes: List[Dict[str, Any]],
+        table_size: Tuple[int, int]
+    ) -> Tuple[List[List[float]], int]:
+        """
+        使用 OCR 补偿遗漏的单元格
+        
+        策略:如果 OCR 文本没有匹配到任何单元格,创建新单元格
+        
+        Args:
+            cells: 现有单元格列表
+            ocr_boxes: OCR结果列表
+            table_size: 表格尺寸 (width, height)
+            
+        Returns:
+            (compensated_cells, compensation_count)
+        """
+        compensated = cells.copy()
+        compensation_count = 0
+        w, h = table_size
+        
+        for ocr in ocr_boxes:
+            ocr_bbox = ocr.get('bbox', [])
+            if not ocr_bbox or len(ocr_bbox) < 4:
+                continue
+            
+            # 计算 OCR 中心点
+            if len(ocr_bbox) == 8:  # poly format
+                ocr_cx = (ocr_bbox[0] + ocr_bbox[2] + ocr_bbox[4] + ocr_bbox[6]) / 4
+                ocr_cy = (ocr_bbox[1] + ocr_bbox[3] + ocr_bbox[5] + ocr_bbox[7]) / 4
+            else:  # bbox format
+                ocr_cx = (ocr_bbox[0] + ocr_bbox[2]) / 2
+                ocr_cy = (ocr_bbox[1] + ocr_bbox[3]) / 2
+            
+            # 检查是否在任何单元格内
+            is_covered = False
+            for cell in compensated:
+                x1, y1, x2, y2 = cell
+                if x1 <= ocr_cx <= x2 and y1 <= ocr_cy <= y2:
+                    is_covered = True
+                    break
+            
+            # 如果孤立,创建新单元格
+            if not is_covered:
+                # 扩展 OCR bbox 作为新单元格
+                if len(ocr_bbox) == 8:
+                    new_cell = [
+                        max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5),
+                        max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5),
+                        min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5),
+                        min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5)
+                    ]
+                else:
+                    new_cell = [
+                        max(0, ocr_bbox[0] - 5),
+                        max(0, ocr_bbox[1] - 5),
+                        min(w, ocr_bbox[2] + 5),
+                        min(h, ocr_bbox[3] + 5)
+                    ]
+                
+                compensated.append(new_cell)
+                compensation_count += 1
+        
+        if compensation_count > 0:
+            logger.debug(f"OCR compensation: added {compensation_count} cells")
+        
+        return compensated, compensation_count
+    
+    def _visualize_fusion(
+        self,
+        table_image: np.ndarray,
+        unet_cells: List[List[float]],
+        rtdetr_cells: List[List[float]],
+        fused_cells: List[List[float]],
+        debug_dir: str,
+        debug_prefix: str
+    ):
+        """可视化融合结果(调试用)"""
+        try:
+            import cv2
+            from pathlib import Path
+            
+            output_dir = Path(debug_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            
+            # 创建三栏对比图
+            h, w = table_image.shape[:2]
+            vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
+            
+            # 左栏:UNet
+            img1 = table_image.copy()
+            for cell in unet_cells:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
+                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            
+            # 中栏:RT-DETR
+            img2 = table_image.copy()
+            for cell in rtdetr_cells:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
+                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
+            
+            # 右栏:融合结果
+            img3 = table_image.copy()
+            for cell in fused_cells:
+                x1, y1, x2, y2 = [int(v) for v in cell]
+                cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 2)
+            cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, 30),
+                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
+            
+            # 拼接
+            vis_canvas[:, :w] = img1
+            vis_canvas[:, w:2*w] = img2
+            vis_canvas[:, 2*w:] = img3
+            
+            # 保存
+            output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
+            cv2.imwrite(str(output_path), vis_canvas)
+            logger.debug(f"💾 Fusion visualization saved: {output_path}")
+            
+        except Exception as e:
+            logger.warning(f"Failed to visualize fusion: {e}")