Browse Source

feat: 新增 DiT Layout Detector 测试脚本

- 创建测试脚本以验证 DiT Layout Detector 适配器的功能,支持 PDF、图像和目录输入。
- 实现页面范围过滤、布局检测和结果统计功能,提供可视化结果保存选项。
- 添加命令行参数解析,支持自定义配置和模型权重设置,提升测试灵活性和可用性。
zhch158_admin 1 tuần trước cách đây
mục cha
commit
bdc29cb5a4
1 tập tin đã thay đổi với 529 bổ sung0 xóa
  1. 529 0
      ocr_tools/universal_doc_parser/tests/test_dit_layout_adapter.py

+ 529 - 0
ocr_tools/universal_doc_parser/tests/test_dit_layout_adapter.py

@@ -0,0 +1,529 @@
+"""
+DiT Layout Detector 测试脚本
+
+测试 DitLayoutDetector 适配器,支持:
+- PDF 文件输入(自动转换为图像)
+- 图像文件输入
+- 目录输入(批量处理)
+- 页面范围过滤
+- 布局检测和结果统计
+- 可视化结果保存
+"""
+
+import sys
+import json
+import argparse
+from pathlib import Path
+from typing import List, Dict, Any
+
+import cv2
+
+# 添加项目根目录到路径
+project_root = Path(__file__).parents[1]
+sys.path.insert(0, str(project_root))
+
+# 添加 ocr_platform 根目录(用于导入 ocr_utils)
+ocr_platform_root = project_root.parents[1]
+if str(ocr_platform_root) not in sys.path:
+    sys.path.insert(0, str(ocr_platform_root))
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+from models.adapters.dit_layout_adapter import DitLayoutDetector
+from ocr_utils.file_utils import convert_pdf_to_images, get_image_files_from_dir
+
+
+def parse_args():
+    """解析命令行参数"""
+    parser = argparse.ArgumentParser(
+        description="测试 DiT Layout Detector 适配器",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例:
+  # 测试 PDF 文件(处理所有页面)
+  python test_dit_layout_adapter.py --input /path/to/document.pdf
+
+  # 测试 PDF 文件(指定页面范围)
+  python test_dit_layout_adapter.py --input /path/to/document.pdf --pages "1-5,10-15"
+
+  # 测试图像文件
+  python test_dit_layout_adapter.py --input /path/to/image.png
+
+  # 测试目录(批量处理)
+  python test_dit_layout_adapter.py --input /path/to/images/ --output-dir ./results
+
+  # 使用自定义配置
+  python test_dit_layout_adapter.py --input /path/to/document.pdf \\
+      --config-file ./custom_config.yaml \\
+      --model-weights /path/to/model.pth \\
+      --device cuda \\
+      --conf 0.5
+        """
+    )
+    
+    parser.add_argument(
+        "--input",
+        type=str,
+        required=True,
+        help="输入路径(PDF文件/图像文件/图像目录)"
+    )
+    
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="输出目录(默认: tests/output/)"
+    )
+    
+    parser.add_argument(
+        "--config-file",
+        type=str,
+        default=None,
+        help="DiT 配置文件路径(可选,默认使用内置配置)"
+    )
+    
+    parser.add_argument(
+        "--model-weights",
+        type=str,
+        default=None,
+        help="模型权重路径或 URL(可选,默认从 HuggingFace 下载)"
+    )
+    
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda", "mps"],
+        help="运行设备 (默认: cpu)"
+    )
+    
+    parser.add_argument(
+        "--conf",
+        type=float,
+        default=0.3,
+        help="置信度阈值 (默认: 0.3)"
+    )
+    
+    parser.add_argument(
+        "--pages",
+        type=str,
+        default=None,
+        help="页面范围(如 '1-5,7,9-12'),仅对 PDF 有效"
+    )
+    
+    parser.add_argument(
+        "--remove-overlap",
+        action="store_true",
+        default=True,
+        help="启用重叠框处理(默认启用)"
+    )
+    
+    parser.add_argument(
+        "--no-remove-overlap",
+        action="store_false",
+        dest="remove_overlap",
+        help="禁用重叠框处理"
+    )
+    
+    parser.add_argument(
+        "--iou-threshold",
+        type=float,
+        default=0.8,
+        help="IoU 阈值 (默认: 0.8)"
+    )
+    
+    parser.add_argument(
+        "--overlap-ratio-threshold",
+        type=float,
+        default=0.8,
+        help="重叠比例阈值 (默认: 0.8)"
+    )
+    
+    parser.add_argument(
+        "--dpi",
+        type=int,
+        default=200,
+        help="PDF 转图像 DPI (默认: 200)"
+    )
+    
+    parser.add_argument(
+        "--save-json",
+        action="store_true",
+        help="保存 JSON 格式的检测结果"
+    )
+    
+    parser.add_argument(
+        "--min-confidence",
+        type=float,
+        default=0.0,
+        help="可视化时的最小置信度阈值 (默认: 0.0)"
+    )
+    
+    return parser.parse_args()
+
+
+def get_input_images(input_path: str, page_range: str = None, dpi: int = 200) -> List[str]:
+    """
+    获取输入图像文件列表
+    
+    Args:
+        input_path: 输入路径(PDF/图像/目录)
+        page_range: 页面范围(仅对 PDF 有效)
+        dpi: PDF 转图像 DPI
+    
+    Returns:
+        图像文件路径列表
+    """
+    input_path_obj = Path(input_path)
+    
+    if not input_path_obj.exists():
+        raise FileNotFoundError(f"输入路径不存在: {input_path}")
+    
+    image_files = []
+    
+    if input_path_obj.is_file():
+        if input_path_obj.suffix.lower() == '.pdf':
+            # PDF 文件:转换为图像
+            print(f"📄 处理 PDF 文件: {input_path_obj.name}")
+            image_files = convert_pdf_to_images(
+                str(input_path_obj),
+                output_dir=None,  # 使用默认输出目录
+                dpi=dpi,
+                page_range=page_range
+            )
+            if not image_files:
+                raise ValueError(f"PDF 转换失败,未生成图像文件")
+            print(f"✅ PDF 转换为 {len(image_files)} 张图像")
+        
+        elif input_path_obj.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']:
+            # 图像文件:直接添加
+            image_files = [str(input_path_obj)]
+            print(f"📷 处理图像文件: {input_path_obj.name}")
+        
+        else:
+            raise ValueError(f"不支持的文件类型: {input_path_obj.suffix}")
+    
+    elif input_path_obj.is_dir():
+        # 目录:扫描所有图像文件
+        image_files = get_image_files_from_dir(input_path_obj)
+        if not image_files:
+            raise ValueError(f"目录中未找到图像文件: {input_path}")
+        print(f"📁 从目录中找到 {len(image_files)} 张图像")
+    
+    else:
+        raise ValueError(f"无效的输入路径: {input_path}")
+    
+    return sorted(image_files)
+
+
+def build_config(args, project_root: Path) -> Dict[str, Any]:
+    """
+    构建检测器配置
+    
+    Args:
+        args: 命令行参数
+        project_root: 项目根目录
+    
+    Returns:
+        配置字典
+    """
+    config = {
+        'device': args.device,
+        'conf': args.conf,
+        'remove_overlap': args.remove_overlap,
+        'iou_threshold': args.iou_threshold,
+        'overlap_ratio_threshold': args.overlap_ratio_threshold,
+    }
+    
+    # 配置文件路径
+    if args.config_file:
+        config['config_file'] = args.config_file
+    else:
+        # 使用默认配置文件
+        default_config_file = project_root / 'dit_support' / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
+        if default_config_file.exists():
+            config['config_file'] = str(default_config_file)
+        else:
+            print(f"⚠️  警告: 默认配置文件不存在: {default_config_file}")
+            print("   请使用 --config-file 指定配置文件路径")
+    
+    # 模型权重
+    if args.model_weights:
+        config['model_weights'] = args.model_weights
+    else:
+        # 使用默认模型权重 URL
+        config['model_weights'] = (
+            'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
+        )
+    
+    return config
+
+
+def process_images(
+    detector: DitLayoutDetector,
+    image_files: List[str],
+    output_dir: Path,
+    save_json: bool = False,
+    min_confidence: float = 0.0
+) -> Dict[str, Any]:
+    """
+    处理图像列表,进行布局检测
+    
+    Args:
+        detector: 布局检测器
+        image_files: 图像文件路径列表
+        output_dir: 输出目录
+        save_json: 是否保存 JSON 结果
+        min_confidence: 最小置信度阈值
+    
+    Returns:
+        统计结果字典
+    """
+    all_results = {}
+    total_stats = {
+        'total_pages': len(image_files),
+        'total_regions': 0,
+        'category_counts': {},
+        'confidence_stats': {
+            'min': float('inf'),
+            'max': 0.0,
+            'sum': 0.0,
+            'count': 0
+        }
+    }
+    
+    for idx, image_path in enumerate(image_files, 1):
+        print(f"\n{'='*60}")
+        print(f"📖 处理图像 {idx}/{len(image_files)}: {Path(image_path).name}")
+        print(f"{'='*60}")
+        
+        # 读取图像
+        img = cv2.imread(image_path)
+        if img is None:
+            print(f"❌ 无法读取图像: {image_path}")
+            continue
+        
+        print(f"   图像尺寸: {img.shape[1]}x{img.shape[0]}")
+        
+        # 执行检测
+        try:
+            results = detector.detect(img)
+            print(f"✅ 检测到 {len(results)} 个区域")
+            
+            # 统计结果
+            page_stats = {
+                'image_path': image_path,
+                'image_size': [img.shape[1], img.shape[0]],
+                'regions': [],
+                'category_counts': {}
+            }
+            
+            for res in results:
+                # 添加到页面统计
+                page_stats['regions'].append({
+                    'category': res['category'],
+                    'bbox': res['bbox'],
+                    'confidence': float(res['confidence']),
+                    'original_label': res.get('raw', {}).get('original_label', 'unknown')
+                })
+                
+                # 更新类别统计
+                cat = res['category']
+                page_stats['category_counts'][cat] = page_stats['category_counts'].get(cat, 0) + 1
+                total_stats['category_counts'][cat] = total_stats['category_counts'].get(cat, 0) + 1
+                
+                # 更新置信度统计
+                conf = res['confidence']
+                total_stats['confidence_stats']['min'] = min(total_stats['confidence_stats']['min'], conf)
+                total_stats['confidence_stats']['max'] = max(total_stats['confidence_stats']['max'], conf)
+                total_stats['confidence_stats']['sum'] += conf
+                total_stats['confidence_stats']['count'] += 1
+            
+            total_stats['total_regions'] += len(results)
+            all_results[image_path] = page_stats
+            
+            # 打印页面统计
+            if page_stats['category_counts']:
+                print(f"\n   类别统计:")
+                for cat, count in sorted(page_stats['category_counts'].items()):
+                    print(f"     - {cat}: {count}")
+            
+            # 可视化
+            if len(results) > 0:
+                print(f"\n   🎨 生成可视化图像...")
+                
+                image_stem = Path(image_path).stem
+                output_path = output_dir / f"{image_stem}_dit_layout_vis.jpg"
+                
+                vis_img = detector.visualize(
+                    img,
+                    results,
+                    output_path=str(output_path),
+                    show_confidence=True,
+                    min_confidence=min_confidence
+                )
+                
+                print(f"   💾 可视化图像已保存: {output_path}")
+            
+            # 保存 JSON 结果
+            if save_json:
+                json_path = output_dir / f"{Path(image_path).stem}_dit_layout_results.json"
+                with open(json_path, 'w', encoding='utf-8') as f:
+                    json.dump(page_stats, f, ensure_ascii=False, indent=2)
+                print(f"   💾 JSON 结果已保存: {json_path}")
+        
+        except Exception as e:
+            print(f"❌ 检测失败: {e}")
+            import traceback
+            traceback.print_exc()
+            continue
+    
+    # 计算平均置信度
+    if total_stats['confidence_stats']['count'] > 0:
+        total_stats['confidence_stats']['mean'] = (
+            total_stats['confidence_stats']['sum'] / total_stats['confidence_stats']['count']
+        )
+    else:
+        total_stats['confidence_stats']['mean'] = 0.0
+        total_stats['confidence_stats']['min'] = 0.0
+    
+    return {
+        'all_results': all_results,
+        'total_stats': total_stats
+    }
+
+
+def print_summary(stats: Dict[str, Any]):
+    """打印统计摘要"""
+    total_stats = stats['total_stats']
+    
+    print(f"\n{'='*60}")
+    print(f"📊 检测结果摘要")
+    print(f"{'='*60}")
+    print(f"总页数: {total_stats['total_pages']}")
+    print(f"总区域数: {total_stats['total_regions']}")
+    
+    if total_stats['total_regions'] > 0:
+        print(f"\n类别统计:")
+        for cat, count in sorted(total_stats['category_counts'].items()):
+            percentage = (count / total_stats['total_regions']) * 100
+            print(f"  - {cat}: {count} ({percentage:.1f}%)")
+        
+        conf_stats = total_stats['confidence_stats']
+        print(f"\n置信度统计:")
+        print(f"  - 最小值: {conf_stats['min']:.3f}")
+        print(f"  - 最大值: {conf_stats['max']:.3f}")
+        print(f"  - 平均值: {conf_stats['mean']:.3f}")
+
+
+def main():
+    """主函数"""
+    args = parse_args()
+    
+    # 设置输出目录
+    if args.output_dir:
+        output_dir = Path(args.output_dir)
+    else:
+        output_dir = Path(__file__).parent / "output"
+    output_dir.mkdir(parents=True, exist_ok=True)
+    print(f"📁 输出目录: {output_dir}")
+    
+    # 获取输入图像列表
+    try:
+        image_files = get_input_images(
+            args.input,
+            page_range=args.pages,
+            dpi=args.dpi
+        )
+    except Exception as e:
+        print(f"❌ 错误: {e}")
+        sys.exit(1)
+    
+    if not image_files:
+        print("❌ 未找到要处理的图像文件")
+        sys.exit(1)
+    
+    # 构建配置
+    project_root = Path(__file__).parents[1]
+    config = build_config(args, project_root)
+    
+    # 初始化检测器
+    print(f"\n{'='*60}")
+    print(f"🔧 初始化 DiT Layout Detector")
+    print(f"{'='*60}")
+    print(f"配置文件: {config.get('config_file', 'N/A')}")
+    print(f"模型权重: {config.get('model_weights', 'N/A')}")
+    print(f"设备: {config['device']}")
+    print(f"置信度阈值: {config['conf']}")
+    print(f"重叠框处理: {config['remove_overlap']}")
+    
+    try:
+        detector = DitLayoutDetector(config)
+        detector.initialize()
+        print("✅ 检测器初始化成功")
+    except Exception as e:
+        print(f"❌ 检测器初始化失败: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
+    
+    # 处理图像
+    try:
+        stats = process_images(
+            detector,
+            image_files,
+            output_dir,
+            save_json=args.save_json,
+            min_confidence=args.min_confidence
+        )
+        
+        # 打印摘要
+        print_summary(stats)
+        
+        # 保存总体统计
+        summary_path = output_dir / "detection_summary.json"
+        with open(summary_path, 'w', encoding='utf-8') as f:
+            json.dump(stats['total_stats'], f, ensure_ascii=False, indent=2)
+        print(f"\n💾 统计摘要已保存: {summary_path}")
+        
+    except Exception as e:
+        print(f"❌ 处理过程中出错: {e}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        # 清理资源
+        detector.cleanup()
+        print("\n✅ 测试完成!")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) == 1:
+        # 没有命令行参数时,使用默认配置运行
+        print("ℹ️  未提供命令行参数,使用默认配置运行...")
+        
+        # 默认配置
+        default_config = {
+            # 测试输入
+            "input": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司.pdf",
+            "output-dir": "./output/2023年度报告母公司_dit_layout_adapter",
+
+            
+            # 页面范围(可选)
+            # "pages": "2-7,24, 26, 29-34",  # 只处理前1页
+            "pages": "32",  # 处理指定页面
+
+			# 是否启用重叠框处理
+			# "no-remove-overlap": True,
+        }
+        
+        # 构造参数
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            if isinstance(value, bool):
+                if value:
+                    sys.argv.append(f"--{key}")
+            else:
+                sys.argv.extend([f"--{key}", str(value)])
+    
+    sys.exit(main())