|
|
@@ -0,0 +1,174 @@
|
|
|
+"""
|
|
|
+测试多源单元格融合系统
|
|
|
+
|
|
|
+验证:
|
|
|
+1. RT-DETR检测器初始化
|
|
|
+2. 融合引擎基本功能
|
|
|
+3. 自适应策略(文字PDF跳过RT-DETR)
|
|
|
+"""
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+# 添加路径
|
|
|
+sys.path.insert(0, str(Path(__file__).parents[5]))
|
|
|
+
|
|
|
+from ocr_tools.universal_doc_parser.models.adapters.paddle_wired_table_cells_detector import PaddleWiredTableCellsDetector
|
|
|
+from ocr_tools.universal_doc_parser.models.adapters.wired_table.cell_fusion import CellFusionEngine
|
|
|
+
|
|
|
+
|
|
|
+def test_rtdetr_detector():
|
|
|
+ """测试 RT-DETR 检测器"""
|
|
|
+ print("=" * 60)
|
|
|
+ print("Test 1: RT-DETR 单元格检测器")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ # 配置
|
|
|
+ config = {
|
|
|
+ 'model_dir': '/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/pytorch_models/Table/RT-DETR-L_wired_table_cell_det.onnx',
|
|
|
+ 'device': 'cpu',
|
|
|
+ 'conf': 0.5
|
|
|
+ }
|
|
|
+
|
|
|
+ # 初始化
|
|
|
+ try:
|
|
|
+ detector = PaddleWiredTableCellsDetector(config)
|
|
|
+ detector.initialize()
|
|
|
+ print("✅ RT-DETR detector initialized successfully")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Failed to initialize: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 测试检测
|
|
|
+ test_image_path = "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/2023年度报告母公司_page_005_270_table.png"
|
|
|
+ if Path(test_image_path).exists():
|
|
|
+ img = cv2.imread(test_image_path)
|
|
|
+ print(f"\n📖 Test image: {img.shape}")
|
|
|
+
|
|
|
+ results = detector.detect(img, conf_threshold=0.5)
|
|
|
+ print(f"✅ Detected {len(results)} cells")
|
|
|
+
|
|
|
+ if len(results) > 0:
|
|
|
+ print(f" Sample cell: {results[0]}")
|
|
|
+ else:
|
|
|
+ print(f"⚠️ Test image not found: {test_image_path}")
|
|
|
+
|
|
|
+ return detector
|
|
|
+
|
|
|
+
|
|
|
+def test_fusion_engine(detector):
|
|
|
+ """测试融合引擎"""
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("Test 2: 融合引擎基本功能")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ # 配置
|
|
|
+ fusion_config = {
|
|
|
+ 'unet_weight': 0.6,
|
|
|
+ 'rtdetr_weight': 0.4,
|
|
|
+ 'iou_merge_threshold': 0.7,
|
|
|
+ 'iou_nms_threshold': 0.5,
|
|
|
+ 'rtdetr_conf_threshold': 0.5,
|
|
|
+ 'enable_ocr_compensation': True,
|
|
|
+ 'skip_rtdetr_for_txt_pdf': True
|
|
|
+ }
|
|
|
+
|
|
|
+ # 初始化
|
|
|
+ engine = CellFusionEngine(rtdetr_detector=detector, config=fusion_config)
|
|
|
+ print("✅ Fusion engine initialized")
|
|
|
+
|
|
|
+ # 模拟数据
|
|
|
+ table_image = np.ones((500, 500, 3), dtype=np.uint8) * 255
|
|
|
+
|
|
|
+ unet_cells = [
|
|
|
+ [10, 10, 100, 50],
|
|
|
+ [10, 60, 100, 100],
|
|
|
+ [110, 10, 200, 50]
|
|
|
+ ]
|
|
|
+
|
|
|
+ ocr_boxes = [
|
|
|
+ {'bbox': [20, 20, 80, 40], 'text': 'Cell 1'},
|
|
|
+ {'bbox': [20, 70, 80, 90], 'text': 'Cell 2'}
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Test 2.1: 文字PDF模式(应跳过RT-DETR)
|
|
|
+ print("\n📄 Test 2.1: Text PDF mode (should skip RT-DETR)")
|
|
|
+ fused_cells, stats = engine.fuse(
|
|
|
+ table_image=table_image,
|
|
|
+ unet_cells=unet_cells,
|
|
|
+ ocr_boxes=ocr_boxes,
|
|
|
+ pdf_type='txt',
|
|
|
+ upscale=1.0
|
|
|
+ )
|
|
|
+ print(f" Use RT-DETR: {stats['use_rtdetr']}")
|
|
|
+ print(f" Fused cells: {len(fused_cells)}")
|
|
|
+ assert not stats['use_rtdetr'], "❌ Should skip RT-DETR for text PDF"
|
|
|
+ assert len(fused_cells) == len(unet_cells), "❌ Should keep UNet cells only"
|
|
|
+ print(" ✅ Correctly skipped RT-DETR for text PDF")
|
|
|
+
|
|
|
+ # Test 2.2: 扫描PDF模式(应启用RT-DETR,但因为是假图片可能失败)
|
|
|
+ print("\n🔍 Test 2.2: Scan PDF mode (should enable RT-DETR)")
|
|
|
+ fused_cells, stats = engine.fuse(
|
|
|
+ table_image=table_image,
|
|
|
+ unet_cells=unet_cells,
|
|
|
+ ocr_boxes=ocr_boxes,
|
|
|
+ pdf_type='ocr',
|
|
|
+ upscale=1.0
|
|
|
+ )
|
|
|
+ print(f" Use RT-DETR: {stats['use_rtdetr']}")
|
|
|
+ print(f" Stats: {stats}")
|
|
|
+ print(" ✅ Fusion completed (RT-DETR may return 0 cells on blank image)")
|
|
|
+
|
|
|
+ return engine
|
|
|
+
|
|
|
+
|
|
|
+def test_adaptive_strategy():
|
|
|
+ """测试自适应策略"""
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("Test 3: 自适应策略测试")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ engine = CellFusionEngine(rtdetr_detector=None, config={'skip_rtdetr_for_txt_pdf': True})
|
|
|
+
|
|
|
+ # Test 3.1: 文字PDF + 正常单元格数 → 跳过
|
|
|
+ should_use = engine.should_use_rtdetr('txt', unet_cell_count=10, table_size=(500, 500))
|
|
|
+ print(f"📄 Text PDF, 10 cells: use_rtdetr={should_use}")
|
|
|
+ assert not should_use, "❌ Should skip RT-DETR"
|
|
|
+ print(" ✅ Correct")
|
|
|
+
|
|
|
+ # Test 3.2: 扫描PDF + 正常单元格数 → 跳过(因为检测器未初始化)
|
|
|
+ should_use = engine.should_use_rtdetr('ocr', unet_cell_count=10, table_size=(500, 500))
|
|
|
+ print(f"🔍 Scan PDF, 10 cells, no detector: use_rtdetr={should_use}")
|
|
|
+ assert not should_use, "❌ Should skip (detector not available)"
|
|
|
+ print(" ✅ Correct")
|
|
|
+
|
|
|
+ # Test 3.3: UNet为空 → 强制启用(但检测器未初始化,仍跳过)
|
|
|
+ should_use = engine.should_use_rtdetr('ocr', unet_cell_count=0, table_size=(500, 500))
|
|
|
+ print(f"🚨 Scan PDF, 0 cells, no detector: use_rtdetr={should_use}")
|
|
|
+ print(" ⚠️ Would force enable if detector available")
|
|
|
+
|
|
|
+ print("\n✅ All adaptive strategy tests passed")
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ print("🚀 多源单元格融合系统测试\n")
|
|
|
+
|
|
|
+ # Test 1: RT-DETR 检测器
|
|
|
+ detector = test_rtdetr_detector()
|
|
|
+
|
|
|
+ # Test 2: 融合引擎
|
|
|
+ if detector:
|
|
|
+ test_fusion_engine(detector)
|
|
|
+
|
|
|
+ # Test 3: 自适应策略
|
|
|
+ test_adaptive_strategy()
|
|
|
+
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("✅ 所有测试完成!")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|