|
@@ -0,0 +1,128 @@
|
|
|
|
|
+"""
|
|
|
|
|
+布局检测器测试脚本
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import sys
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+import cv2
|
|
|
|
|
+import random
|
|
|
|
|
+
|
|
|
|
|
+# 添加项目根目录到路径
|
|
|
|
|
+project_root = Path(__file__).parents[1]
|
|
|
|
|
+sys.path.insert(0, str(project_root))
|
|
|
|
|
+
|
|
|
|
|
+from models.adapters import PaddleLayoutDetector
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def test_layout_detector():
|
|
|
|
|
+ """测试 PaddleX 布局检测器"""
|
|
|
|
|
+
|
|
|
|
|
+ # 测试配置
|
|
|
|
|
+ config = {
|
|
|
|
|
+ 'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
|
|
|
|
|
+ 'device': 'cpu',
|
|
|
|
|
+ 'conf': 0.25
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 初始化检测器
|
|
|
|
|
+ print("🔧 Initializing detector...")
|
|
|
|
|
+ detector = PaddleLayoutDetector(config)
|
|
|
|
|
+ detector.initialize()
|
|
|
|
|
+
|
|
|
|
|
+ # 读取测试图像
|
|
|
|
|
+ img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
|
|
|
|
|
+ print(f"\n📖 Loading image: {img_path}")
|
|
|
|
|
+ img = cv2.imread(img_path)
|
|
|
|
|
+
|
|
|
|
|
+ if img is None:
|
|
|
|
|
+ print(f"❌ Failed to load image: {img_path}")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ print(f" Image shape: {img.shape}")
|
|
|
|
|
+
|
|
|
|
|
+ # 执行检测
|
|
|
|
|
+ print("\n🔍 Detecting layout...")
|
|
|
|
|
+ results = detector.detect(img)
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n✅ 检测到 {len(results)} 个区域:")
|
|
|
|
|
+ for i, res in enumerate(results, 1):
|
|
|
|
|
+ print(f" [{i}] {res['category']}: "
|
|
|
|
|
+ f"score={res['confidence']:.3f}, "
|
|
|
|
|
+ f"bbox={res['bbox']}, "
|
|
|
|
|
+ f"size={res['raw']['width']}x{res['raw']['height']}, "
|
|
|
|
|
+ f"original={res['raw']['original_category_name']}")
|
|
|
|
|
+
|
|
|
|
|
+ # 统计各类别
|
|
|
|
|
+ category_counts = {}
|
|
|
|
|
+ for res in results:
|
|
|
|
|
+ cat = res['category']
|
|
|
|
|
+ category_counts[cat] = category_counts.get(cat, 0) + 1
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n📊 类别统计 (MinerU格式):")
|
|
|
|
|
+ for cat, count in sorted(category_counts.items()):
|
|
|
|
|
+ print(f" - {cat}: {count}")
|
|
|
|
|
+
|
|
|
|
|
+ # 可视化结果
|
|
|
|
|
+ if len(results) > 0:
|
|
|
|
|
+ print("\n🎨 Generating visualization...")
|
|
|
|
|
+
|
|
|
|
|
+ # 为每个类别分配颜色
|
|
|
|
|
+ category_colors = {}
|
|
|
|
|
+ for res in results:
|
|
|
|
|
+ cat = res['category']
|
|
|
|
|
+ if cat not in category_colors:
|
|
|
|
|
+ category_colors[cat] = (
|
|
|
|
|
+ random.randint(50, 255),
|
|
|
|
|
+ random.randint(50, 255),
|
|
|
|
|
+ random.randint(50, 255)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 绘制检测框
|
|
|
|
|
+ vis_img = img.copy()
|
|
|
|
|
+ for res in results:
|
|
|
|
|
+ bbox = res['bbox']
|
|
|
|
|
+ x1, y1, x2, y2 = bbox
|
|
|
|
|
+ cat = res['category']
|
|
|
|
|
+ color = category_colors[cat]
|
|
|
|
|
+
|
|
|
|
|
+ # 绘制矩形
|
|
|
|
|
+ cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 绘制标签
|
|
|
|
|
+ label = f"{cat} {res['confidence']:.2f}"
|
|
|
|
|
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
|
|
|
|
+
|
|
|
|
|
+ # 标签背景
|
|
|
|
|
+ cv2.rectangle(
|
|
|
|
|
+ vis_img,
|
|
|
|
|
+ (x1, y1 - label_size[1] - 4),
|
|
|
|
|
+ (x1 + label_size[0], y1),
|
|
|
|
|
+ color,
|
|
|
|
|
+ -1
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 标签文字
|
|
|
|
|
+ cv2.putText(
|
|
|
|
|
+ vis_img,
|
|
|
|
|
+ label,
|
|
|
|
|
+ (x1, y1 - 2),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
|
|
+ 0.5,
|
|
|
|
|
+ (255, 255, 255),
|
|
|
|
|
+ 1
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 保存可视化结果
|
|
|
|
|
+ output_dir = Path(__file__).parent / "output"
|
|
|
|
|
+ output_dir.mkdir(exist_ok=True)
|
|
|
|
|
+ output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
|
|
|
|
|
+ cv2.imwrite(str(output_path), vis_img)
|
|
|
|
|
+ print(f"💾 Visualization saved to: {output_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # 清理
|
|
|
|
|
+ detector.cleanup()
|
|
|
|
|
+ print("\n✅ 测试完成!")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ test_layout_detector()
|