|
@@ -0,0 +1,529 @@
|
|
|
|
|
+"""
|
|
|
|
|
+DiT Layout Detector 测试脚本
|
|
|
|
|
+
|
|
|
|
|
+测试 DitLayoutDetector 适配器,支持:
|
|
|
|
|
+- PDF 文件输入(自动转换为图像)
|
|
|
|
|
+- 图像文件输入
|
|
|
|
|
+- 目录输入(批量处理)
|
|
|
|
|
+- 页面范围过滤
|
|
|
|
|
+- 布局检测和结果统计
|
|
|
|
|
+- 可视化结果保存
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import sys
|
|
|
|
|
+import json
|
|
|
|
|
+import argparse
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import List, Dict, Any
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
|
|
+
|
|
|
|
|
+# 添加项目根目录到路径
|
|
|
|
|
+project_root = Path(__file__).parents[1]
|
|
|
|
|
+sys.path.insert(0, str(project_root))
|
|
|
|
|
+
|
|
|
|
|
+# 添加 ocr_platform 根目录(用于导入 ocr_utils)
|
|
|
|
|
+ocr_platform_root = project_root.parents[1]
|
|
|
|
|
+if str(ocr_platform_root) not in sys.path:
|
|
|
|
|
+ sys.path.insert(0, str(ocr_platform_root))
|
|
|
|
|
+
|
|
|
|
|
+from dotenv import load_dotenv
|
|
|
|
|
+load_dotenv(override=True)
|
|
|
|
|
+
|
|
|
|
|
+from models.adapters.dit_layout_adapter import DitLayoutDetector
|
|
|
|
|
+from ocr_utils.file_utils import convert_pdf_to_images, get_image_files_from_dir
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def parse_args():
|
|
|
|
|
+ """解析命令行参数"""
|
|
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
|
|
+ description="测试 DiT Layout Detector 适配器",
|
|
|
|
|
+ formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
|
|
|
+ epilog="""
|
|
|
|
|
+示例:
|
|
|
|
|
+ # 测试 PDF 文件(处理所有页面)
|
|
|
|
|
+ python test_dit_layout_adapter.py --input /path/to/document.pdf
|
|
|
|
|
+
|
|
|
|
|
+ # 测试 PDF 文件(指定页面范围)
|
|
|
|
|
+ python test_dit_layout_adapter.py --input /path/to/document.pdf --pages "1-5,10-15"
|
|
|
|
|
+
|
|
|
|
|
+ # 测试图像文件
|
|
|
|
|
+ python test_dit_layout_adapter.py --input /path/to/image.png
|
|
|
|
|
+
|
|
|
|
|
+ # 测试目录(批量处理)
|
|
|
|
|
+ python test_dit_layout_adapter.py --input /path/to/images/ --output-dir ./results
|
|
|
|
|
+
|
|
|
|
|
+ # 使用自定义配置
|
|
|
|
|
+ python test_dit_layout_adapter.py --input /path/to/document.pdf \\
|
|
|
|
|
+ --config-file ./custom_config.yaml \\
|
|
|
|
|
+ --model-weights /path/to/model.pth \\
|
|
|
|
|
+ --device cuda \\
|
|
|
|
|
+ --conf 0.5
|
|
|
|
|
+ """
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--input",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ required=True,
|
|
|
|
|
+ help="输入路径(PDF文件/图像文件/图像目录)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--output-dir",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="输出目录(默认: tests/output/)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--config-file",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="DiT 配置文件路径(可选,默认使用内置配置)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--model-weights",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="模型权重路径或 URL(可选,默认从 HuggingFace 下载)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--device",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default="cpu",
|
|
|
|
|
+ choices=["cpu", "cuda", "mps"],
|
|
|
|
|
+ help="运行设备 (默认: cpu)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--conf",
|
|
|
|
|
+ type=float,
|
|
|
|
|
+ default=0.3,
|
|
|
|
|
+ help="置信度阈值 (默认: 0.3)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--pages",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="页面范围(如 '1-5,7,9-12'),仅对 PDF 有效"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--remove-overlap",
|
|
|
|
|
+ action="store_true",
|
|
|
|
|
+ default=True,
|
|
|
|
|
+ help="启用重叠框处理(默认启用)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--no-remove-overlap",
|
|
|
|
|
+ action="store_false",
|
|
|
|
|
+ dest="remove_overlap",
|
|
|
|
|
+ help="禁用重叠框处理"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--iou-threshold",
|
|
|
|
|
+ type=float,
|
|
|
|
|
+ default=0.8,
|
|
|
|
|
+ help="IoU 阈值 (默认: 0.8)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--overlap-ratio-threshold",
|
|
|
|
|
+ type=float,
|
|
|
|
|
+ default=0.8,
|
|
|
|
|
+ help="重叠比例阈值 (默认: 0.8)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--dpi",
|
|
|
|
|
+ type=int,
|
|
|
|
|
+ default=200,
|
|
|
|
|
+ help="PDF 转图像 DPI (默认: 200)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--save-json",
|
|
|
|
|
+ action="store_true",
|
|
|
|
|
+ help="保存 JSON 格式的检测结果"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--min-confidence",
|
|
|
|
|
+ type=float,
|
|
|
|
|
+ default=0.0,
|
|
|
|
|
+ help="可视化时的最小置信度阈值 (默认: 0.0)"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return parser.parse_args()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_input_images(input_path: str, page_range: str = None, dpi: int = 200) -> List[str]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取输入图像文件列表
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ input_path: 输入路径(PDF/图像/目录)
|
|
|
|
|
+ page_range: 页面范围(仅对 PDF 有效)
|
|
|
|
|
+ dpi: PDF 转图像 DPI
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 图像文件路径列表
|
|
|
|
|
+ """
|
|
|
|
|
+ input_path_obj = Path(input_path)
|
|
|
|
|
+
|
|
|
|
|
+ if not input_path_obj.exists():
|
|
|
|
|
+ raise FileNotFoundError(f"输入路径不存在: {input_path}")
|
|
|
|
|
+
|
|
|
|
|
+ image_files = []
|
|
|
|
|
+
|
|
|
|
|
+ if input_path_obj.is_file():
|
|
|
|
|
+ if input_path_obj.suffix.lower() == '.pdf':
|
|
|
|
|
+ # PDF 文件:转换为图像
|
|
|
|
|
+ print(f"📄 处理 PDF 文件: {input_path_obj.name}")
|
|
|
|
|
+ image_files = convert_pdf_to_images(
|
|
|
|
|
+ str(input_path_obj),
|
|
|
|
|
+ output_dir=None, # 使用默认输出目录
|
|
|
|
|
+ dpi=dpi,
|
|
|
|
|
+ page_range=page_range
|
|
|
|
|
+ )
|
|
|
|
|
+ if not image_files:
|
|
|
|
|
+ raise ValueError(f"PDF 转换失败,未生成图像文件")
|
|
|
|
|
+ print(f"✅ PDF 转换为 {len(image_files)} 张图像")
|
|
|
|
|
+
|
|
|
|
|
+ elif input_path_obj.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']:
|
|
|
|
|
+ # 图像文件:直接添加
|
|
|
|
|
+ image_files = [str(input_path_obj)]
|
|
|
|
|
+ print(f"📷 处理图像文件: {input_path_obj.name}")
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"不支持的文件类型: {input_path_obj.suffix}")
|
|
|
|
|
+
|
|
|
|
|
+ elif input_path_obj.is_dir():
|
|
|
|
|
+ # 目录:扫描所有图像文件
|
|
|
|
|
+ image_files = get_image_files_from_dir(input_path_obj)
|
|
|
|
|
+ if not image_files:
|
|
|
|
|
+ raise ValueError(f"目录中未找到图像文件: {input_path}")
|
|
|
|
|
+ print(f"📁 从目录中找到 {len(image_files)} 张图像")
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"无效的输入路径: {input_path}")
|
|
|
|
|
+
|
|
|
|
|
+ return sorted(image_files)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_config(args, project_root: Path) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 构建检测器配置
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ args: 命令行参数
|
|
|
|
|
+ project_root: 项目根目录
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 配置字典
|
|
|
|
|
+ """
|
|
|
|
|
+ config = {
|
|
|
|
|
+ 'device': args.device,
|
|
|
|
|
+ 'conf': args.conf,
|
|
|
|
|
+ 'remove_overlap': args.remove_overlap,
|
|
|
|
|
+ 'iou_threshold': args.iou_threshold,
|
|
|
|
|
+ 'overlap_ratio_threshold': args.overlap_ratio_threshold,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 配置文件路径
|
|
|
|
|
+ if args.config_file:
|
|
|
|
|
+ config['config_file'] = args.config_file
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 使用默认配置文件
|
|
|
|
|
+ default_config_file = project_root / 'dit_support' / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
|
|
|
|
|
+ if default_config_file.exists():
|
|
|
|
|
+ config['config_file'] = str(default_config_file)
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f"⚠️ 警告: 默认配置文件不存在: {default_config_file}")
|
|
|
|
|
+ print(" 请使用 --config-file 指定配置文件路径")
|
|
|
|
|
+
|
|
|
|
|
+ # 模型权重
|
|
|
|
|
+ if args.model_weights:
|
|
|
|
|
+ config['model_weights'] = args.model_weights
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 使用默认模型权重 URL
|
|
|
|
|
+ config['model_weights'] = (
|
|
|
|
|
+ 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return config
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def process_images(
|
|
|
|
|
+ detector: DitLayoutDetector,
|
|
|
|
|
+ image_files: List[str],
|
|
|
|
|
+ output_dir: Path,
|
|
|
|
|
+ save_json: bool = False,
|
|
|
|
|
+ min_confidence: float = 0.0
|
|
|
|
|
+) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 处理图像列表,进行布局检测
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ detector: 布局检测器
|
|
|
|
|
+ image_files: 图像文件路径列表
|
|
|
|
|
+ output_dir: 输出目录
|
|
|
|
|
+ save_json: 是否保存 JSON 结果
|
|
|
|
|
+ min_confidence: 最小置信度阈值
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 统计结果字典
|
|
|
|
|
+ """
|
|
|
|
|
+ all_results = {}
|
|
|
|
|
+ total_stats = {
|
|
|
|
|
+ 'total_pages': len(image_files),
|
|
|
|
|
+ 'total_regions': 0,
|
|
|
|
|
+ 'category_counts': {},
|
|
|
|
|
+ 'confidence_stats': {
|
|
|
|
|
+ 'min': float('inf'),
|
|
|
|
|
+ 'max': 0.0,
|
|
|
|
|
+ 'sum': 0.0,
|
|
|
|
|
+ 'count': 0
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for idx, image_path in enumerate(image_files, 1):
|
|
|
|
|
+ print(f"\n{'='*60}")
|
|
|
|
|
+ print(f"📖 处理图像 {idx}/{len(image_files)}: {Path(image_path).name}")
|
|
|
|
|
+ print(f"{'='*60}")
|
|
|
|
|
+
|
|
|
|
|
+ # 读取图像
|
|
|
|
|
+ img = cv2.imread(image_path)
|
|
|
|
|
+ if img is None:
|
|
|
|
|
+ print(f"❌ 无法读取图像: {image_path}")
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ print(f" 图像尺寸: {img.shape[1]}x{img.shape[0]}")
|
|
|
|
|
+
|
|
|
|
|
+ # 执行检测
|
|
|
|
|
+ try:
|
|
|
|
|
+ results = detector.detect(img)
|
|
|
|
|
+ print(f"✅ 检测到 {len(results)} 个区域")
|
|
|
|
|
+
|
|
|
|
|
+ # 统计结果
|
|
|
|
|
+ page_stats = {
|
|
|
|
|
+ 'image_path': image_path,
|
|
|
|
|
+ 'image_size': [img.shape[1], img.shape[0]],
|
|
|
|
|
+ 'regions': [],
|
|
|
|
|
+ 'category_counts': {}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for res in results:
|
|
|
|
|
+ # 添加到页面统计
|
|
|
|
|
+ page_stats['regions'].append({
|
|
|
|
|
+ 'category': res['category'],
|
|
|
|
|
+ 'bbox': res['bbox'],
|
|
|
|
|
+ 'confidence': float(res['confidence']),
|
|
|
|
|
+ 'original_label': res.get('raw', {}).get('original_label', 'unknown')
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ # 更新类别统计
|
|
|
|
|
+ cat = res['category']
|
|
|
|
|
+ page_stats['category_counts'][cat] = page_stats['category_counts'].get(cat, 0) + 1
|
|
|
|
|
+ total_stats['category_counts'][cat] = total_stats['category_counts'].get(cat, 0) + 1
|
|
|
|
|
+
|
|
|
|
|
+ # 更新置信度统计
|
|
|
|
|
+ conf = res['confidence']
|
|
|
|
|
+ total_stats['confidence_stats']['min'] = min(total_stats['confidence_stats']['min'], conf)
|
|
|
|
|
+ total_stats['confidence_stats']['max'] = max(total_stats['confidence_stats']['max'], conf)
|
|
|
|
|
+ total_stats['confidence_stats']['sum'] += conf
|
|
|
|
|
+ total_stats['confidence_stats']['count'] += 1
|
|
|
|
|
+
|
|
|
|
|
+ total_stats['total_regions'] += len(results)
|
|
|
|
|
+ all_results[image_path] = page_stats
|
|
|
|
|
+
|
|
|
|
|
+ # 打印页面统计
|
|
|
|
|
+ if page_stats['category_counts']:
|
|
|
|
|
+ print(f"\n 类别统计:")
|
|
|
|
|
+ for cat, count in sorted(page_stats['category_counts'].items()):
|
|
|
|
|
+ print(f" - {cat}: {count}")
|
|
|
|
|
+
|
|
|
|
|
+ # 可视化
|
|
|
|
|
+ if len(results) > 0:
|
|
|
|
|
+ print(f"\n 🎨 生成可视化图像...")
|
|
|
|
|
+
|
|
|
|
|
+ image_stem = Path(image_path).stem
|
|
|
|
|
+ output_path = output_dir / f"{image_stem}_dit_layout_vis.jpg"
|
|
|
|
|
+
|
|
|
|
|
+ vis_img = detector.visualize(
|
|
|
|
|
+ img,
|
|
|
|
|
+ results,
|
|
|
|
|
+ output_path=str(output_path),
|
|
|
|
|
+ show_confidence=True,
|
|
|
|
|
+ min_confidence=min_confidence
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ print(f" 💾 可视化图像已保存: {output_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # 保存 JSON 结果
|
|
|
|
|
+ if save_json:
|
|
|
|
|
+ json_path = output_dir / f"{Path(image_path).stem}_dit_layout_results.json"
|
|
|
|
|
+ with open(json_path, 'w', encoding='utf-8') as f:
|
|
|
|
|
+ json.dump(page_stats, f, ensure_ascii=False, indent=2)
|
|
|
|
|
+ print(f" 💾 JSON 结果已保存: {json_path}")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"❌ 检测失败: {e}")
|
|
|
|
|
+ import traceback
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 计算平均置信度
|
|
|
|
|
+ if total_stats['confidence_stats']['count'] > 0:
|
|
|
|
|
+ total_stats['confidence_stats']['mean'] = (
|
|
|
|
|
+ total_stats['confidence_stats']['sum'] / total_stats['confidence_stats']['count']
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ total_stats['confidence_stats']['mean'] = 0.0
|
|
|
|
|
+ total_stats['confidence_stats']['min'] = 0.0
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ 'all_results': all_results,
|
|
|
|
|
+ 'total_stats': total_stats
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def print_summary(stats: Dict[str, Any]):
|
|
|
|
|
+ """打印统计摘要"""
|
|
|
|
|
+ total_stats = stats['total_stats']
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n{'='*60}")
|
|
|
|
|
+ print(f"📊 检测结果摘要")
|
|
|
|
|
+ print(f"{'='*60}")
|
|
|
|
|
+ print(f"总页数: {total_stats['total_pages']}")
|
|
|
|
|
+ print(f"总区域数: {total_stats['total_regions']}")
|
|
|
|
|
+
|
|
|
|
|
+ if total_stats['total_regions'] > 0:
|
|
|
|
|
+ print(f"\n类别统计:")
|
|
|
|
|
+ for cat, count in sorted(total_stats['category_counts'].items()):
|
|
|
|
|
+ percentage = (count / total_stats['total_regions']) * 100
|
|
|
|
|
+ print(f" - {cat}: {count} ({percentage:.1f}%)")
|
|
|
|
|
+
|
|
|
|
|
+ conf_stats = total_stats['confidence_stats']
|
|
|
|
|
+ print(f"\n置信度统计:")
|
|
|
|
|
+ print(f" - 最小值: {conf_stats['min']:.3f}")
|
|
|
|
|
+ print(f" - 最大值: {conf_stats['max']:.3f}")
|
|
|
|
|
+ print(f" - 平均值: {conf_stats['mean']:.3f}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def main():
|
|
|
|
|
+ """主函数"""
|
|
|
|
|
+ args = parse_args()
|
|
|
|
|
+
|
|
|
|
|
+ # 设置输出目录
|
|
|
|
|
+ if args.output_dir:
|
|
|
|
|
+ output_dir = Path(args.output_dir)
|
|
|
|
|
+ else:
|
|
|
|
|
+ output_dir = Path(__file__).parent / "output"
|
|
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ print(f"📁 输出目录: {output_dir}")
|
|
|
|
|
+
|
|
|
|
|
+ # 获取输入图像列表
|
|
|
|
|
+ try:
|
|
|
|
|
+ image_files = get_input_images(
|
|
|
|
|
+ args.input,
|
|
|
|
|
+ page_range=args.pages,
|
|
|
|
|
+ dpi=args.dpi
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"❌ 错误: {e}")
|
|
|
|
|
+ sys.exit(1)
|
|
|
|
|
+
|
|
|
|
|
+ if not image_files:
|
|
|
|
|
+ print("❌ 未找到要处理的图像文件")
|
|
|
|
|
+ sys.exit(1)
|
|
|
|
|
+
|
|
|
|
|
+ # 构建配置
|
|
|
|
|
+ project_root = Path(__file__).parents[1]
|
|
|
|
|
+ config = build_config(args, project_root)
|
|
|
|
|
+
|
|
|
|
|
+ # 初始化检测器
|
|
|
|
|
+ print(f"\n{'='*60}")
|
|
|
|
|
+ print(f"🔧 初始化 DiT Layout Detector")
|
|
|
|
|
+ print(f"{'='*60}")
|
|
|
|
|
+ print(f"配置文件: {config.get('config_file', 'N/A')}")
|
|
|
|
|
+ print(f"模型权重: {config.get('model_weights', 'N/A')}")
|
|
|
|
|
+ print(f"设备: {config['device']}")
|
|
|
|
|
+ print(f"置信度阈值: {config['conf']}")
|
|
|
|
|
+ print(f"重叠框处理: {config['remove_overlap']}")
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ detector = DitLayoutDetector(config)
|
|
|
|
|
+ detector.initialize()
|
|
|
|
|
+ print("✅ 检测器初始化成功")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"❌ 检测器初始化失败: {e}")
|
|
|
|
|
+ import traceback
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+ sys.exit(1)
|
|
|
|
|
+
|
|
|
|
|
+ # 处理图像
|
|
|
|
|
+ try:
|
|
|
|
|
+ stats = process_images(
|
|
|
|
|
+ detector,
|
|
|
|
|
+ image_files,
|
|
|
|
|
+ output_dir,
|
|
|
|
|
+ save_json=args.save_json,
|
|
|
|
|
+ min_confidence=args.min_confidence
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 打印摘要
|
|
|
|
|
+ print_summary(stats)
|
|
|
|
|
+
|
|
|
|
|
+ # 保存总体统计
|
|
|
|
|
+ summary_path = output_dir / "detection_summary.json"
|
|
|
|
|
+ with open(summary_path, 'w', encoding='utf-8') as f:
|
|
|
|
|
+ json.dump(stats['total_stats'], f, ensure_ascii=False, indent=2)
|
|
|
|
|
+ print(f"\n💾 统计摘要已保存: {summary_path}")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"❌ 处理过程中出错: {e}")
|
|
|
|
|
+ import traceback
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ # 清理资源
|
|
|
|
|
+ detector.cleanup()
|
|
|
|
|
+ print("\n✅ 测试完成!")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ if len(sys.argv) == 1:
|
|
|
|
|
+ # 没有命令行参数时,使用默认配置运行
|
|
|
|
|
+ print("ℹ️ 未提供命令行参数,使用默认配置运行...")
|
|
|
|
|
+
|
|
|
|
|
+ # 默认配置
|
|
|
|
|
+ default_config = {
|
|
|
|
|
+ # 测试输入
|
|
|
|
|
+ "input": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司.pdf",
|
|
|
|
|
+ "output-dir": "./output/2023年度报告母公司_dit_layout_adapter",
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ # 页面范围(可选)
|
|
|
|
|
+ # "pages": "2-7,24, 26, 29-34", # 只处理前1页
|
|
|
|
|
+ "pages": "32", # 处理指定页面
|
|
|
|
|
+
|
|
|
|
|
+ # 是否启用重叠框处理
|
|
|
|
|
+ # "no-remove-overlap": True,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 构造参数
|
|
|
|
|
+ sys.argv = [sys.argv[0]]
|
|
|
|
|
+ for key, value in default_config.items():
|
|
|
|
|
+ if isinstance(value, bool):
|
|
|
|
|
+ if value:
|
|
|
|
|
+ sys.argv.append(f"--{key}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ sys.argv.extend([f"--{key}", str(value)])
|
|
|
|
|
+
|
|
|
|
|
+ sys.exit(main())
|