test_layout_detector.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. """
  2. 布局检测器测试脚本
  3. """
  4. import sys
  5. from pathlib import Path
  6. import cv2
  7. import random
  8. # 添加项目根目录到路径
  9. project_root = Path(__file__).parents[1]
  10. sys.path.insert(0, str(project_root))
  11. from models.adapters import PaddleLayoutDetector
  12. def test_layout_detector():
  13. """测试 PaddleX 布局检测器"""
  14. # 测试配置
  15. config = {
  16. 'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
  17. 'device': 'cpu',
  18. 'conf': 0.25
  19. }
  20. # 初始化检测器
  21. print("🔧 Initializing detector...")
  22. detector = PaddleLayoutDetector(config)
  23. detector.initialize()
  24. # 读取测试图像
  25. img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
  26. print(f"\n📖 Loading image: {img_path}")
  27. img = cv2.imread(img_path)
  28. if img is None:
  29. print(f"❌ Failed to load image: {img_path}")
  30. return
  31. print(f" Image shape: {img.shape}")
  32. # 执行检测
  33. print("\n🔍 Detecting layout...")
  34. results = detector.detect(img)
  35. print(f"\n✅ 检测到 {len(results)} 个区域:")
  36. for i, res in enumerate(results, 1):
  37. print(f" [{i}] {res['category']}: "
  38. f"score={res['confidence']:.3f}, "
  39. f"bbox={res['bbox']}, "
  40. f"size={res['raw']['width']}x{res['raw']['height']}, "
  41. f"original={res['raw']['original_category_name']}")
  42. # 统计各类别
  43. category_counts = {}
  44. for res in results:
  45. cat = res['category']
  46. category_counts[cat] = category_counts.get(cat, 0) + 1
  47. print(f"\n📊 类别统计 (MinerU格式):")
  48. for cat, count in sorted(category_counts.items()):
  49. print(f" - {cat}: {count}")
  50. # 可视化结果
  51. if len(results) > 0:
  52. print("\n🎨 Generating visualization...")
  53. # 为每个类别分配颜色
  54. category_colors = {}
  55. for res in results:
  56. cat = res['category']
  57. if cat not in category_colors:
  58. category_colors[cat] = (
  59. random.randint(50, 255),
  60. random.randint(50, 255),
  61. random.randint(50, 255)
  62. )
  63. # 绘制检测框
  64. vis_img = img.copy()
  65. for res in results:
  66. bbox = res['bbox']
  67. x1, y1, x2, y2 = bbox
  68. cat = res['category']
  69. color = category_colors[cat]
  70. # 绘制矩形
  71. cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
  72. # 绘制标签
  73. label = f"{cat} {res['confidence']:.2f}"
  74. label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
  75. # 标签背景
  76. cv2.rectangle(
  77. vis_img,
  78. (x1, y1 - label_size[1] - 4),
  79. (x1 + label_size[0], y1),
  80. color,
  81. -1
  82. )
  83. # 标签文字
  84. cv2.putText(
  85. vis_img,
  86. label,
  87. (x1, y1 - 2),
  88. cv2.FONT_HERSHEY_SIMPLEX,
  89. 0.5,
  90. (255, 255, 255),
  91. 1
  92. )
  93. # 保存可视化结果
  94. output_dir = Path(__file__).parent / "output"
  95. output_dir.mkdir(exist_ok=True)
  96. output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
  97. cv2.imwrite(str(output_path), vis_img)
  98. print(f"💾 Visualization saved to: {output_path}")
  99. # 清理
  100. detector.cleanup()
  101. print("\n✅ 测试完成!")
  102. if __name__ == "__main__":
  103. test_layout_detector()