Sfoglia il codice sorgente

feat(cell_fusion_tests): 添加多源单元格融合系统测试,包括RT-DETR检测器和融合引擎功能验证

zhch158_admin 2 settimane fa
parent
commit
01d8ee5005

+ 46 - 0
ocr_tools/universal_doc_parser/tests/cell_fusion_config_example.yaml

@@ -0,0 +1,46 @@
+# 多源单元格融合配置示例
+# 用于 MinerUWiredTableRecognizer
+
+wired_table_recognizer:
+  # 基础配置
+  upscale_ratio: 3.333  # 10/3
+  use_custom_postprocess: true  # 启用 v4 流程
+  
+  # 🆕 启用多源单元格融合
+  use_cell_fusion: true
+  
+  # 融合引擎配置
+  cell_fusion:
+    # RT-DETR 模型路径(必需)
+    rtdetr_model_path: "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/pytorch_models/Table/RT-DETR-L_wired_table_cell_det.onnx"
+    
+    # 融合权重
+    unet_weight: 0.6        # UNet 权重(结构性强)
+    rtdetr_weight: 0.4      # RT-DETR 权重(鲁棒性强)
+    
+    # 阈值配置
+    iou_merge_threshold: 0.7    # 高IoU合并阈值(>0.7则加权平均)
+    iou_nms_threshold: 0.5      # NMS去重阈值
+    rtdetr_conf_threshold: 0.5  # RT-DETR置信度阈值
+    
+    # 功能开关
+    enable_ocr_compensation: true      # 启用OCR孤立文本补偿
+    skip_rtdetr_for_txt_pdf: true      # 🎯 文字PDF跳过RT-DETR(自适应策略)
+  
+  # 调试选项
+  debug_options:
+    enabled: true
+    output_dir: "debug_output/table_fusion"
+    prefix: "table"
+    save_fusion_comparison: true  # 保存融合对比图
+
+# 使用说明:
+# 1. 文字PDF (pdf_type='txt'): 自动跳过RT-DETR,使用纯UNet模式(无噪声干扰)
+# 2. 扫描PDF (pdf_type='ocr'): 启用融合模式,结合UNet、RT-DETR和OCR三路结果
+# 3. UNet结果为空: 强制启用RT-DETR补救
+# 4. 融合失败: 自动降级到UNet-only模式
+
+# 性能优化建议:
+# - 小表格(<50单元格): 考虑禁用融合(use_cell_fusion: false)
+# - 高质量图像: 提高 unet_weight 到 0.7-0.8
+# - 模糊图像: 降低 unet_weight 到 0.4-0.5

+ 174 - 0
ocr_tools/universal_doc_parser/tests/test_cell_fusion.py

@@ -0,0 +1,174 @@
+"""
+测试多源单元格融合系统
+
+验证:
+1. RT-DETR检测器初始化
+2. 融合引擎基本功能
+3. 自适应策略(文字PDF跳过RT-DETR)
+"""
+
+import cv2
+import numpy as np
+import sys
+from pathlib import Path
+
+# 添加路径
+sys.path.insert(0, str(Path(__file__).parents[5]))
+
+from ocr_tools.universal_doc_parser.models.adapters.paddle_wired_table_cells_detector import PaddleWiredTableCellsDetector
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.cell_fusion import CellFusionEngine
+
+
+def test_rtdetr_detector():
+    """测试 RT-DETR 检测器"""
+    print("=" * 60)
+    print("Test 1: RT-DETR 单元格检测器")
+    print("=" * 60)
+    
+    # 配置
+    config = {
+        'model_dir': '/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/pytorch_models/Table/RT-DETR-L_wired_table_cell_det.onnx',
+        'device': 'cpu',
+        'conf': 0.5
+    }
+    
+    # 初始化
+    try:
+        detector = PaddleWiredTableCellsDetector(config)
+        detector.initialize()
+        print("✅ RT-DETR detector initialized successfully")
+    except Exception as e:
+        print(f"❌ Failed to initialize: {e}")
+        return None
+    
+    # 测试检测
+    test_image_path = "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/2023年度报告母公司_page_005_270_table.png"
+    if Path(test_image_path).exists():
+        img = cv2.imread(test_image_path)
+        print(f"\n📖 Test image: {img.shape}")
+        
+        results = detector.detect(img, conf_threshold=0.5)
+        print(f"✅ Detected {len(results)} cells")
+        
+        if len(results) > 0:
+            print(f"   Sample cell: {results[0]}")
+    else:
+        print(f"⚠️ Test image not found: {test_image_path}")
+    
+    return detector
+
+
+def test_fusion_engine(detector):
+    """测试融合引擎"""
+    print("\n" + "=" * 60)
+    print("Test 2: 融合引擎基本功能")
+    print("=" * 60)
+    
+    # 配置
+    fusion_config = {
+        'unet_weight': 0.6,
+        'rtdetr_weight': 0.4,
+        'iou_merge_threshold': 0.7,
+        'iou_nms_threshold': 0.5,
+        'rtdetr_conf_threshold': 0.5,
+        'enable_ocr_compensation': True,
+        'skip_rtdetr_for_txt_pdf': True
+    }
+    
+    # 初始化
+    engine = CellFusionEngine(rtdetr_detector=detector, config=fusion_config)
+    print("✅ Fusion engine initialized")
+    
+    # 模拟数据
+    table_image = np.ones((500, 500, 3), dtype=np.uint8) * 255
+    
+    unet_cells = [
+        [10, 10, 100, 50],
+        [10, 60, 100, 100],
+        [110, 10, 200, 50]
+    ]
+    
+    ocr_boxes = [
+        {'bbox': [20, 20, 80, 40], 'text': 'Cell 1'},
+        {'bbox': [20, 70, 80, 90], 'text': 'Cell 2'}
+    ]
+    
+    # Test 2.1: 文字PDF模式(应跳过RT-DETR)
+    print("\n📄 Test 2.1: Text PDF mode (should skip RT-DETR)")
+    fused_cells, stats = engine.fuse(
+        table_image=table_image,
+        unet_cells=unet_cells,
+        ocr_boxes=ocr_boxes,
+        pdf_type='txt',
+        upscale=1.0
+    )
+    print(f"   Use RT-DETR: {stats['use_rtdetr']}")
+    print(f"   Fused cells: {len(fused_cells)}")
+    assert not stats['use_rtdetr'], "❌ Should skip RT-DETR for text PDF"
+    assert len(fused_cells) == len(unet_cells), "❌ Should keep UNet cells only"
+    print("   ✅ Correctly skipped RT-DETR for text PDF")
+    
+    # Test 2.2: 扫描PDF模式(应启用RT-DETR,但因为是假图片可能失败)
+    print("\n🔍 Test 2.2: Scan PDF mode (should enable RT-DETR)")
+    fused_cells, stats = engine.fuse(
+        table_image=table_image,
+        unet_cells=unet_cells,
+        ocr_boxes=ocr_boxes,
+        pdf_type='ocr',
+        upscale=1.0
+    )
+    print(f"   Use RT-DETR: {stats['use_rtdetr']}")
+    print(f"   Stats: {stats}")
+    print("   ✅ Fusion completed (RT-DETR may return 0 cells on blank image)")
+    
+    return engine
+
+
+def test_adaptive_strategy():
+    """测试自适应策略"""
+    print("\n" + "=" * 60)
+    print("Test 3: 自适应策略测试")
+    print("=" * 60)
+    
+    engine = CellFusionEngine(rtdetr_detector=None, config={'skip_rtdetr_for_txt_pdf': True})
+    
+    # Test 3.1: 文字PDF + 正常单元格数 → 跳过
+    should_use = engine.should_use_rtdetr('txt', unet_cell_count=10, table_size=(500, 500))
+    print(f"📄 Text PDF, 10 cells: use_rtdetr={should_use}")
+    assert not should_use, "❌ Should skip RT-DETR"
+    print("   ✅ Correct")
+    
+    # Test 3.2: 扫描PDF + 正常单元格数 → 跳过(因为检测器未初始化)
+    should_use = engine.should_use_rtdetr('ocr', unet_cell_count=10, table_size=(500, 500))
+    print(f"🔍 Scan PDF, 10 cells, no detector: use_rtdetr={should_use}")
+    assert not should_use, "❌ Should skip (detector not available)"
+    print("   ✅ Correct")
+    
+    # Test 3.3: UNet为空 → 强制启用(但检测器未初始化,仍跳过)
+    should_use = engine.should_use_rtdetr('ocr', unet_cell_count=0, table_size=(500, 500))
+    print(f"🚨 Scan PDF, 0 cells, no detector: use_rtdetr={should_use}")
+    print("   ⚠️ Would force enable if detector available")
+    
+    print("\n✅ All adaptive strategy tests passed")
+
+
+def main():
+    print("🚀 多源单元格融合系统测试\n")
+    
+    # Test 1: RT-DETR 检测器
+    detector = test_rtdetr_detector()
+    
+    # Test 2: 融合引擎
+    if detector:
+        test_fusion_engine(detector)
+    
+    # Test 3: 自适应策略
+    test_adaptive_strategy()
+    
+    print("\n" + "=" * 60)
+    print("✅ 所有测试完成!")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()