Quellcode durchsuchen

feat(cell_fusion): 添加多源单元格融合支持,集成 UNet 和 RT-DETR

zhch158_admin vor 3 Wochen
Ursprung
Commit
637bcf9318

+ 63 - 0
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -17,6 +17,16 @@ from .wired_table.grid_recovery import GridRecovery
 from .wired_table.text_filling import TextFiller
 from .wired_table.html_generator import WiredTableHTMLGenerator
 from .wired_table.visualization import WiredTableVisualizer
+from .wired_table.cell_fusion import CellFusionEngine
+
+# 导入 RT-DETR 单元格检测器
+try:
+    from .paddle_wired_table_cells_detector import PaddleWiredTableCellsDetector
+    RTDETR_AVAILABLE = True
+except ImportError:
+    RTDETR_AVAILABLE = False
+    PaddleWiredTableCellsDetector = None
+    logger.warning("RT-DETR cell detector not available, fusion mode disabled")
 
 # 确保 mineru 库可导入
 from pathlib import Path
@@ -57,6 +67,40 @@ class MinerUWiredTableRecognizer:
         self.text_filler = TextFiller(ocr_engine, self.config)
         self.html_generator = WiredTableHTMLGenerator()
         self.visualizer = WiredTableVisualizer()
+        
+        # 初始化单元格融合引擎(可选)
+        self.cell_fusion_engine = None
+        self.use_cell_fusion = self.config.get("use_cell_fusion", False)
+        
+        if self.use_cell_fusion and RTDETR_AVAILABLE:
+            try:
+                # 获取融合配置
+                fusion_config = self.config.get("cell_fusion", {})
+                rtdetr_model_path = fusion_config.get("rtdetr_model_path")
+                
+                if rtdetr_model_path:
+                    # 初始化 RT-DETR 检测器
+                    rtdetr_config = {
+                        'model_dir': rtdetr_model_path,
+                        'device': self.config.get('device', 'cpu'),
+                        'conf': fusion_config.get('rtdetr_conf_threshold', 0.5)
+                    }
+                    rtdetr_detector = PaddleWiredTableCellsDetector(rtdetr_config)
+                    rtdetr_detector.initialize()
+                    
+                    # 初始化融合引擎
+                    self.cell_fusion_engine = CellFusionEngine(
+                        rtdetr_detector=rtdetr_detector,
+                        config=fusion_config
+                    )
+                    logger.info("🔧 Cell fusion engine enabled")
+                else:
+                    logger.warning("⚠️ Cell fusion enabled but rtdetr_model_path not configured")
+            except Exception as e:
+                logger.error(f"❌ Failed to initialize cell fusion engine: {e}")
+                self.cell_fusion_engine = None
+        elif self.use_cell_fusion and not RTDETR_AVAILABLE:
+            logger.warning("⚠️ Cell fusion enabled but RT-DETR detector not available")
 
     # ========== 倾斜检测与矫正 ==========
     
@@ -352,6 +396,25 @@ class MinerUWiredTableRecognizer:
             if not bboxes:
                 raise RuntimeError("未能提取出单元格")
 
+            # Step 2.3: 🆕 多源单元格融合(UNet + RT-DETR + OCR)
+            fusion_stats = {}
+            if self.cell_fusion_engine:
+                try:
+                    logger.debug(f"🔀 Starting multi-source cell fusion (pdf_type={pdf_type})")
+                    bboxes, fusion_stats = self.cell_fusion_engine.fuse(
+                        table_image=table_image,
+                        unet_cells=bboxes,
+                        ocr_boxes=ocr_boxes or [],
+                        pdf_type=pdf_type,
+                        upscale=upscale,
+                        debug_dir=debug_dir,
+                        debug_prefix=debug_prefix
+                    )
+                    logger.info(f"✅ Cell fusion completed: {fusion_stats}")
+                except Exception as e:
+                    logger.error(f"❌ Cell fusion failed: {e}, using UNet-only results")
+                    # 融合失败,继续使用 UNet 结果
+            
             # Step 2.5: 可视化连通域(线条+框,直观版)
             if self.debug_utils.debug_is_on("save_connected_components", dbg):
                 out_path = self.debug_utils.debug_path("connected_components", dbg)