|
|
@@ -17,6 +17,16 @@ from .wired_table.grid_recovery import GridRecovery
|
|
|
from .wired_table.text_filling import TextFiller
|
|
|
from .wired_table.html_generator import WiredTableHTMLGenerator
|
|
|
from .wired_table.visualization import WiredTableVisualizer
|
|
|
+from .wired_table.cell_fusion import CellFusionEngine
|
|
|
+
|
|
|
+# 导入 RT-DETR 单元格检测器
|
|
|
+try:
|
|
|
+ from .paddle_wired_table_cells_detector import PaddleWiredTableCellsDetector
|
|
|
+ RTDETR_AVAILABLE = True
|
|
|
+except ImportError:
|
|
|
+ RTDETR_AVAILABLE = False
|
|
|
+ PaddleWiredTableCellsDetector = None
|
|
|
+ logger.warning("RT-DETR cell detector not available, fusion mode disabled")
|
|
|
|
|
|
# 确保 mineru 库可导入
|
|
|
from pathlib import Path
|
|
|
@@ -57,6 +67,40 @@ class MinerUWiredTableRecognizer:
|
|
|
self.text_filler = TextFiller(ocr_engine, self.config)
|
|
|
self.html_generator = WiredTableHTMLGenerator()
|
|
|
self.visualizer = WiredTableVisualizer()
|
|
|
+
|
|
|
+ # 初始化单元格融合引擎(可选)
|
|
|
+ self.cell_fusion_engine = None
|
|
|
+ self.use_cell_fusion = self.config.get("use_cell_fusion", False)
|
|
|
+
|
|
|
+ if self.use_cell_fusion and RTDETR_AVAILABLE:
|
|
|
+ try:
|
|
|
+ # 获取融合配置
|
|
|
+ fusion_config = self.config.get("cell_fusion", {})
|
|
|
+ rtdetr_model_path = fusion_config.get("rtdetr_model_path")
|
|
|
+
|
|
|
+ if rtdetr_model_path:
|
|
|
+ # 初始化 RT-DETR 检测器
|
|
|
+ rtdetr_config = {
|
|
|
+ 'model_dir': rtdetr_model_path,
|
|
|
+ 'device': self.config.get('device', 'cpu'),
|
|
|
+ 'conf': fusion_config.get('rtdetr_conf_threshold', 0.5)
|
|
|
+ }
|
|
|
+ rtdetr_detector = PaddleWiredTableCellsDetector(rtdetr_config)
|
|
|
+ rtdetr_detector.initialize()
|
|
|
+
|
|
|
+ # 初始化融合引擎
|
|
|
+ self.cell_fusion_engine = CellFusionEngine(
|
|
|
+ rtdetr_detector=rtdetr_detector,
|
|
|
+ config=fusion_config
|
|
|
+ )
|
|
|
+ logger.info("🔧 Cell fusion engine enabled")
|
|
|
+ else:
|
|
|
+ logger.warning("⚠️ Cell fusion enabled but rtdetr_model_path not configured")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ Failed to initialize cell fusion engine: {e}")
|
|
|
+ self.cell_fusion_engine = None
|
|
|
+ elif self.use_cell_fusion and not RTDETR_AVAILABLE:
|
|
|
+ logger.warning("⚠️ Cell fusion enabled but RT-DETR detector not available")
|
|
|
|
|
|
# ========== 倾斜检测与矫正 ==========
|
|
|
|
|
|
@@ -352,6 +396,25 @@ class MinerUWiredTableRecognizer:
|
|
|
if not bboxes:
|
|
|
raise RuntimeError("未能提取出单元格")
|
|
|
|
|
|
+ # Step 2.3: 🆕 多源单元格融合(UNet + RT-DETR + OCR)
|
|
|
+ fusion_stats = {}
|
|
|
+ if self.cell_fusion_engine:
|
|
|
+ try:
|
|
|
+ logger.debug(f"🔀 Starting multi-source cell fusion (pdf_type={pdf_type})")
|
|
|
+ bboxes, fusion_stats = self.cell_fusion_engine.fuse(
|
|
|
+ table_image=table_image,
|
|
|
+ unet_cells=bboxes,
|
|
|
+ ocr_boxes=ocr_boxes or [],
|
|
|
+ pdf_type=pdf_type,
|
|
|
+ upscale=upscale,
|
|
|
+ debug_dir=debug_dir,
|
|
|
+ debug_prefix=debug_prefix
|
|
|
+ )
|
|
|
+ logger.info(f"✅ Cell fusion completed: {fusion_stats}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ Cell fusion failed: {e}, using UNet-only results")
|
|
|
+ # 融合失败,继续使用 UNet 结果
|
|
|
+
|
|
|
# Step 2.5: 可视化连通域(线条+框,直观版)
|
|
|
if self.debug_utils.debug_is_on("save_connected_components", dbg):
|
|
|
out_path = self.debug_utils.debug_path("connected_components", dbg)
|