Explorar o código

feat: 新增布局检测器测试脚本,支持图像加载、检测及可视化结果

zhch158_admin hai 2 semanas
pai
achega
4c56f9c64a
Modificáronse 1 ficheiros con 128 adicións e 0 borrados
  1. 128 0
      zhch/universal_doc_parser/tests/test_layout_detector.py

+ 128 - 0
zhch/universal_doc_parser/tests/test_layout_detector.py

@@ -0,0 +1,128 @@
+"""
+布局检测器测试脚本
+"""
+
+import sys
+from pathlib import Path
+import cv2
+import random
+
+# 添加项目根目录到路径
+project_root = Path(__file__).parents[1]
+sys.path.insert(0, str(project_root))
+
+from models.adapters import PaddleLayoutDetector
+
+
+def test_layout_detector():
+    """测试 PaddleX 布局检测器"""
+    
+    # 测试配置
+    config = {
+        'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
+        'device': 'cpu',
+        'conf': 0.25
+    }
+    
+    # 初始化检测器
+    print("🔧 Initializing detector...")
+    detector = PaddleLayoutDetector(config)
+    detector.initialize()
+    
+    # 读取测试图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
+    print(f"\n📖 Loading image: {img_path}")
+    img = cv2.imread(img_path)
+    
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        return
+    
+    print(f"   Image shape: {img.shape}")
+    
+    # 执行检测
+    print("\n🔍 Detecting layout...")
+    results = detector.detect(img)
+    
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(f"  [{i}] {res['category']}: "
+              f"score={res['confidence']:.3f}, "
+              f"bbox={res['bbox']}, "
+              f"size={res['raw']['width']}x{res['raw']['height']}, "
+              f"original={res['raw']['original_category_name']}")
+    
+    # 统计各类别
+    category_counts = {}
+    for res in results:
+        cat = res['category']
+        category_counts[cat] = category_counts.get(cat, 0) + 1
+    
+    print(f"\n📊 类别统计 (MinerU格式):")
+    for cat, count in sorted(category_counts.items()):
+        print(f"  - {cat}: {count}")
+    
+    # 可视化结果
+    if len(results) > 0:
+        print("\n🎨 Generating visualization...")
+        
+        # 为每个类别分配颜色
+        category_colors = {}
+        for res in results:
+            cat = res['category']
+            if cat not in category_colors:
+                category_colors[cat] = (
+                    random.randint(50, 255),
+                    random.randint(50, 255),
+                    random.randint(50, 255)
+                )
+        
+        # 绘制检测框
+        vis_img = img.copy()
+        for res in results:
+            bbox = res['bbox']
+            x1, y1, x2, y2 = bbox
+            cat = res['category']
+            color = category_colors[cat]
+            
+            # 绘制矩形
+            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
+            
+            # 绘制标签
+            label = f"{cat} {res['confidence']:.2f}"
+            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+            
+            # 标签背景
+            cv2.rectangle(
+                vis_img,
+                (x1, y1 - label_size[1] - 4),
+                (x1 + label_size[0], y1),
+                color,
+                -1
+            )
+            
+            # 标签文字
+            cv2.putText(
+                vis_img,
+                label,
+                (x1, y1 - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.5,
+                (255, 255, 255),
+                1
+            )
+        
+        # 保存可视化结果
+        output_dir = Path(__file__).parent / "output"
+        output_dir.mkdir(exist_ok=True)
+        output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
+        cv2.imwrite(str(output_path), vis_img)
+        print(f"💾 Visualization saved to: {output_path}")
+    
+    # 清理
+    detector.cleanup()
+    print("\n✅ 测试完成!")
+
+
+if __name__ == "__main__":
+    test_layout_detector()