|
|
@@ -0,0 +1,531 @@
|
|
|
+"""DocLayout-YOLO 批量布局检测工具"""
|
|
|
+import json
|
|
|
+import time
|
|
|
+import os
|
|
|
+import traceback
|
|
|
+import argparse
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Dict, Any, Union
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
+from doclayout_yolo import YOLOv10
|
|
|
+import numpy as np
|
|
|
+from PIL import Image, ImageDraw, ImageFont
|
|
|
+
|
|
|
+from mineru.utils.enum_class import ModelPath
|
|
|
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
+
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
+from utils import get_input_files
|
|
|
+
|
|
|
+
|
|
|
+class DocLayoutYOLOModel:
|
|
|
+ """DocLayout-YOLO 模型封装类"""
|
|
|
+
|
|
|
+ # 类别ID到名称的映射(DocLayout-YOLO标准类别)
|
|
|
+ # 类型定义参见:docs/en/reference/output_files.md
|
|
|
+ CATEGORY_NAMES = {
|
|
|
+ 0: 'title', # Title
|
|
|
+ 1: 'plain_text', # text
|
|
|
+ 2: 'abandon', # Including headers, footers, page numbers, and page annotations
|
|
|
+ 3: 'figure', # Image
|
|
|
+ 4: 'figure_caption', # Image caption
|
|
|
+ 5: 'table', # Table
|
|
|
+ 6: 'table_caption', # Table caption
|
|
|
+ 7: 'table_footnote', # Table footnote
|
|
|
+ 8: 'isolate_formula', # Interline formula
|
|
|
+ 9: 'formula_caption', # Interline formula number
|
|
|
+ 13: 'embedding', # Inline formula
|
|
|
+ 14: 'isolated', # Interline formula
|
|
|
+ 15: 'text', # OCR recognition result
|
|
|
+ }
|
|
|
+
|
|
|
+ # 类别对应的颜色(与MinerU保持一致)
|
|
|
+ CATEGORY_COLORS = {
|
|
|
+ 0: (102, 102, 255), # title: 蓝色
|
|
|
+ 1: (153, 0, 76), # plain_text: 深红
|
|
|
+ 2: (158, 158, 158), # abandon: 灰色
|
|
|
+ 3: (153, 255, 51), # figure: 绿色
|
|
|
+ 4: (102, 178, 255), # figure_caption: 浅蓝
|
|
|
+ 5: (204, 204, 0), # table: 黄色
|
|
|
+ 6: (255, 255, 102), # table_caption: 浅黄
|
|
|
+ 7: (229, 255, 204), # table_footnote: 浅绿
|
|
|
+ 8: (0, 255, 0), # isolate_formula: 亮绿
|
|
|
+ 9: (255, 0, 0), # formula_caption: 红色
|
|
|
+ }
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ weight: str,
|
|
|
+ device: str = "cuda",
|
|
|
+ imgsz: int = 1280,
|
|
|
+ conf: float = 0.25,
|
|
|
+ iou: float = 0.45,
|
|
|
+ ):
|
|
|
+ self.model = YOLOv10(weight).to(device)
|
|
|
+ self.device = device
|
|
|
+ self.imgsz = imgsz
|
|
|
+ self.conf = conf
|
|
|
+ self.iou = iou
|
|
|
+
|
|
|
+ def _parse_prediction(self, prediction) -> List[Dict]:
|
|
|
+ """解析模型预测结果"""
|
|
|
+ layout_res = []
|
|
|
+
|
|
|
+ if not hasattr(prediction, "boxes") or prediction.boxes is None:
|
|
|
+ return layout_res
|
|
|
+
|
|
|
+ for xyxy, conf, cls in zip(
|
|
|
+ prediction.boxes.xyxy.cpu(),
|
|
|
+ prediction.boxes.conf.cpu(),
|
|
|
+ prediction.boxes.cls.cpu(),
|
|
|
+ ):
|
|
|
+ coords = list(map(int, xyxy.tolist()))
|
|
|
+ xmin, ymin, xmax, ymax = coords
|
|
|
+ layout_res.append({
|
|
|
+ "category_id": int(cls.item()),
|
|
|
+ "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ "score": round(float(conf.item()), 3),
|
|
|
+ })
|
|
|
+ return layout_res
|
|
|
+
|
|
|
+ def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
|
|
|
+ """单张图片预测"""
|
|
|
+ prediction = self.model.predict(
|
|
|
+ image,
|
|
|
+ imgsz=self.imgsz,
|
|
|
+ conf=self.conf,
|
|
|
+ iou=self.iou,
|
|
|
+ verbose=False
|
|
|
+ )[0]
|
|
|
+ return self._parse_prediction(prediction)
|
|
|
+
|
|
|
+ def batch_predict(
|
|
|
+ self,
|
|
|
+ images: List[Union[np.ndarray, Image.Image]],
|
|
|
+ batch_size: int = 4
|
|
|
+ ) -> List[List[Dict]]:
|
|
|
+ """批量预测"""
|
|
|
+ results = []
|
|
|
+ with tqdm(total=len(images), desc="Layout Predict", disable=True) as pbar:
|
|
|
+ for idx in range(0, len(images), batch_size):
|
|
|
+ batch = images[idx: idx + batch_size]
|
|
|
+ conf = 0.9 * self.conf if batch_size == 1 else self.conf
|
|
|
+
|
|
|
+ predictions = self.model.predict(
|
|
|
+ batch,
|
|
|
+ imgsz=self.imgsz,
|
|
|
+ conf=conf,
|
|
|
+ iou=self.iou,
|
|
|
+ verbose=False,
|
|
|
+ )
|
|
|
+ for pred in predictions:
|
|
|
+ results.append(self._parse_prediction(pred))
|
|
|
+ pbar.update(len(batch))
|
|
|
+ return results
|
|
|
+
|
|
|
+ def visualize(
|
|
|
+ self,
|
|
|
+ image: Union[np.ndarray, Image.Image],
|
|
|
+ results: List[Dict],
|
|
|
+ output_path: str = None,
|
|
|
+ draw_type_label: bool = True,
|
|
|
+ draw_score: bool = True,
|
|
|
+ draw_order_number: bool = False,
|
|
|
+ font_size: int = 14,
|
|
|
+ line_width: int = 2,
|
|
|
+ verbose: bool = False
|
|
|
+ ) -> Image.Image:
|
|
|
+ """可视化布局检测结果"""
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ image: 输入图像(PIL Image或numpy array)
|
|
|
+ results: 检测结果列表
|
|
|
+ output_path: 输出图片路径(如果为None则不保存)
|
|
|
+ draw_type_label: 是否标注类型名称(默认True)
|
|
|
+ draw_score: 是否标注置信度分数(默认True)
|
|
|
+ draw_order_number: 是否标注检测顺序编号(默认False)
|
|
|
+ font_size: 字体大小(默认14)
|
|
|
+ line_width: 边框线宽(默认2)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ PIL.Image: 标注后的图像
|
|
|
+ """
|
|
|
+ # 1. 转换图像格式
|
|
|
+ if isinstance(image, np.ndarray):
|
|
|
+ image = Image.fromarray(image)
|
|
|
+ else:
|
|
|
+ image = image.copy()
|
|
|
+
|
|
|
+ draw = ImageDraw.Draw(image)
|
|
|
+
|
|
|
+ # 2. 尝试加载字体
|
|
|
+ try:
|
|
|
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", font_size)
|
|
|
+ except:
|
|
|
+ try:
|
|
|
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
|
|
|
+ except:
|
|
|
+ font = ImageFont.load_default()
|
|
|
+
|
|
|
+ # 3. 绘制每个检测框
|
|
|
+ for idx, res in enumerate(results, 1):
|
|
|
+ poly = res['poly']
|
|
|
+ xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
|
|
|
+ category_id = res['category_id']
|
|
|
+ score = res['score']
|
|
|
+
|
|
|
+ # 获取类别名称和颜色
|
|
|
+ category_name = self.CATEGORY_NAMES.get(category_id, f'unknown_{category_id}')
|
|
|
+ color = self.CATEGORY_COLORS.get(category_id, (255, 0, 0))
|
|
|
+
|
|
|
+ # 3.1 绘制边框
|
|
|
+ draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)
|
|
|
+
|
|
|
+ # 3.2 准备标注文本
|
|
|
+ labels = []
|
|
|
+ if draw_type_label:
|
|
|
+ labels.append(category_name)
|
|
|
+ if draw_score:
|
|
|
+ labels.append(f"{score:.2f}")
|
|
|
+ if draw_order_number:
|
|
|
+ labels.append(f"#{idx}")
|
|
|
+
|
|
|
+ label_text = " | ".join(labels) if labels else ""
|
|
|
+
|
|
|
+ # 3.3 绘制标注文本(如果有)
|
|
|
+ if label_text:
|
|
|
+ # 计算文本背景框
|
|
|
+ bbox = draw.textbbox((xmin + 2, ymin + 2), label_text, font=font)
|
|
|
+ # 绘制半透明背景
|
|
|
+ draw.rectangle(bbox, fill=(*color, 200))
|
|
|
+ # 绘制文本
|
|
|
+ draw.text((xmin + 2, ymin + 2), label_text, fill='white', font=font)
|
|
|
+
|
|
|
+ if verbose:
|
|
|
+ print(f"Box #{idx}: {category_name} [{xmin}, {ymin}, {xmax}, {ymax}] score={score:.3f}")
|
|
|
+
|
|
|
+ # 4. 保存到文件(如果指定了路径)
|
|
|
+ if output_path:
|
|
|
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
|
|
|
+ image.save(output_path)
|
|
|
+ if verbose:
|
|
|
+ print(f"✅ Layout visualization saved to: {output_path}")
|
|
|
+
|
|
|
+ return image
|
|
|
+
|
|
|
+
|
|
|
+def process_images_batch(
|
|
|
+ image_paths: List[str],
|
|
|
+ model: DocLayoutYOLOModel,
|
|
|
+ output_dir: str = "./output",
|
|
|
+ draw_type_label: bool = True,
|
|
|
+ draw_score: bool = True,
|
|
|
+ draw_order_number: bool = False,
|
|
|
+ save_json: bool = True,
|
|
|
+ save_visualization: bool = True,
|
|
|
+ font_size: int = 14,
|
|
|
+ line_width: int = 2
|
|
|
+) -> List[Dict[str, Any]]:
|
|
|
+ """批量处理图像"""
|
|
|
+
|
|
|
+ # 创建输出目录
|
|
|
+ output_path = Path(output_dir)
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ if save_json:
|
|
|
+ json_dir = output_path / "json"
|
|
|
+ json_dir.mkdir(exist_ok=True)
|
|
|
+
|
|
|
+ if save_visualization:
|
|
|
+ viz_dir = output_path / "visualization"
|
|
|
+ viz_dir.mkdir(exist_ok=True)
|
|
|
+
|
|
|
+ all_results = []
|
|
|
+ total_images = len(image_paths)
|
|
|
+
|
|
|
+ print(f"Processing {total_images} images")
|
|
|
+
|
|
|
+ # 使用tqdm显示进度
|
|
|
+ with tqdm(total=total_images, desc="Processing images", unit="img",
|
|
|
+ bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
+
|
|
|
+ for img_path in image_paths:
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 加载图片
|
|
|
+ image = Image.open(img_path)
|
|
|
+
|
|
|
+ # 预测
|
|
|
+ results = model.predict(image)
|
|
|
+ processing_time = time.time() - start_time
|
|
|
+
|
|
|
+ # 生成输出文件名
|
|
|
+ input_path = Path(img_path)
|
|
|
+ output_filename = input_path.stem
|
|
|
+
|
|
|
+ # 保存JSON
|
|
|
+ json_output_path = None
|
|
|
+ if save_json:
|
|
|
+ json_output_path = json_dir / f"{output_filename}_layout.json"
|
|
|
+ with open(json_output_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump({
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "image_size": list(image.size),
|
|
|
+ "layout_results": results,
|
|
|
+ "processing_time": processing_time
|
|
|
+ }, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ # 保存可视化
|
|
|
+ viz_output_path = None
|
|
|
+ if save_visualization:
|
|
|
+ viz_output_path = viz_dir / f"{output_filename}_layout.png"
|
|
|
+ model.visualize(
|
|
|
+ image,
|
|
|
+ results,
|
|
|
+ output_path=str(viz_output_path),
|
|
|
+ draw_type_label=draw_type_label,
|
|
|
+ draw_score=draw_score,
|
|
|
+ draw_order_number=draw_order_number,
|
|
|
+ font_size=font_size,
|
|
|
+ line_width=line_width,
|
|
|
+ verbose=False
|
|
|
+ )
|
|
|
+
|
|
|
+ # 记录结果
|
|
|
+ all_results.append({
|
|
|
+ "image_path": str(input_path),
|
|
|
+ "processing_time": processing_time,
|
|
|
+ "success": True,
|
|
|
+ "num_detections": len(results),
|
|
|
+ "output_json": str(json_output_path) if json_output_path else None,
|
|
|
+ "output_viz": str(viz_output_path) if viz_output_path else None,
|
|
|
+ "detections": results
|
|
|
+ })
|
|
|
+
|
|
|
+ # 更新进度条
|
|
|
+ success_count = sum(1 for r in all_results if r.get('success', False))
|
|
|
+ pbar.update(1)
|
|
|
+ pbar.set_postfix({
|
|
|
+ 'time': f"{processing_time:.2f}s",
|
|
|
+ 'boxes': len(results),
|
|
|
+ 'success': f"{success_count}/{len(all_results)}"
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"\n❌ Error processing {Path(img_path).name}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ all_results.append({
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "error": str(e)
|
|
|
+ })
|
|
|
+ pbar.update(1)
|
|
|
+
|
|
|
+ return all_results
|
|
|
+
|
|
|
+
|
|
|
+def collect_results(results: List[Dict], output_csv: str):
|
|
|
+ """收集处理结果到CSV"""
|
|
|
+ import csv
|
|
|
+
|
|
|
+ with open(output_csv, 'w', encoding='utf-8', newline='') as f:
|
|
|
+ writer = csv.writer(f)
|
|
|
+ writer.writerow(['image_path', 'status', 'num_detections', 'processing_time'])
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ writer.writerow([
|
|
|
+ result['image_path'],
|
|
|
+ 'success' if result.get('success', False) else 'failed',
|
|
|
+ result.get('num_detections', 0),
|
|
|
+ f"{result.get('processing_time', 0):.2f}"
|
|
|
+ ])
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ """主函数"""
|
|
|
+ parser = argparse.ArgumentParser(description="DocLayout-YOLO Batch Layout Detection Tool")
|
|
|
+
|
|
|
+ # 输入参数
|
|
|
+ input_group = parser.add_mutually_exclusive_group(required=True)
|
|
|
+ input_group.add_argument("--input_file", type=str, help="Input image file")
|
|
|
+ input_group.add_argument("--input_dir", type=str, help="Input directory")
|
|
|
+ input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
|
|
|
+ input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path column")
|
|
|
+
|
|
|
+ # 输出参数
|
|
|
+ parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
|
|
|
+ parser.add_argument("--collect_results", type=str, help="Collect results to CSV file")
|
|
|
+
|
|
|
+ # 模型参数
|
|
|
+ parser.add_argument("--model_path", type=str, help="Custom model path (optional)")
|
|
|
+ parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device")
|
|
|
+ parser.add_argument("--imgsz", type=int, default=1280, help="Image size for inference")
|
|
|
+ parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
|
|
|
+ parser.add_argument("--iou", type=float, default=0.45, help="IoU threshold")
|
|
|
+
|
|
|
+ # 可视化参数
|
|
|
+ parser.add_argument("--no-visualization", action="store_true", help="Disable visualization output")
|
|
|
+ parser.add_argument("--no-json", action="store_true", help="Disable JSON output")
|
|
|
+ parser.add_argument("--draw_type_label", action="store_true", default=True, help="Draw type labels")
|
|
|
+ parser.add_argument("--draw_score", action="store_true", default=True, help="Draw confidence scores")
|
|
|
+ parser.add_argument("--draw_order_number", action="store_true", help="Draw detection order numbers")
|
|
|
+ parser.add_argument("--font_size", type=int, default=14, help="Font size for labels")
|
|
|
+ parser.add_argument("--line_width", type=int, default=2, help="Line width for bounding boxes")
|
|
|
+
|
|
|
+ # 其他参数
|
|
|
+ parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 获取输入文件
|
|
|
+ print("🔄 Getting input files...")
|
|
|
+ input_files = get_input_files(args)
|
|
|
+
|
|
|
+ if not input_files:
|
|
|
+ print("❌ No input files found")
|
|
|
+ return 1
|
|
|
+
|
|
|
+ if args.test_mode:
|
|
|
+ input_files = input_files[:20]
|
|
|
+ print(f"🧪 Test mode: processing only {len(input_files)} images")
|
|
|
+
|
|
|
+ # 加载模型
|
|
|
+ print("🔄 Loading DocLayout-YOLO model...")
|
|
|
+ if args.model_path:
|
|
|
+ model_path = args.model_path
|
|
|
+ else:
|
|
|
+ model_path = os.path.join(
|
|
|
+ auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
|
|
|
+ ModelPath.doclayout_yolo
|
|
|
+ )
|
|
|
+
|
|
|
+ model = DocLayoutYOLOModel(
|
|
|
+ weight=model_path,
|
|
|
+ device=args.device,
|
|
|
+ imgsz=args.imgsz,
|
|
|
+ conf=args.conf,
|
|
|
+ iou=args.iou
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"✅ Model loaded: {model_path}")
|
|
|
+ print(f"🔧 Device: {args.device}")
|
|
|
+ print(f"🔧 Image size: {args.imgsz}")
|
|
|
+ print(f"🔧 Confidence threshold: {args.conf}")
|
|
|
+ print(f"🔧 IoU threshold: {args.iou}")
|
|
|
+
|
|
|
+ # 开始处理
|
|
|
+ start_time = time.time()
|
|
|
+ results = process_images_batch(
|
|
|
+ input_files,
|
|
|
+ model,
|
|
|
+ args.output_dir,
|
|
|
+ draw_type_label=args.draw_type_label,
|
|
|
+ draw_score=args.draw_score,
|
|
|
+ draw_order_number=args.draw_order_number,
|
|
|
+ save_json=not args.no_json,
|
|
|
+ save_visualization=not args.no_visualization,
|
|
|
+ font_size=args.font_size,
|
|
|
+ line_width=args.line_width
|
|
|
+ )
|
|
|
+ total_time = time.time() - start_time
|
|
|
+
|
|
|
+ # 统计结果
|
|
|
+ success_count = sum(1 for r in results if r.get('success', False))
|
|
|
+ error_count = len(results) - success_count
|
|
|
+ total_detections = sum(r.get('num_detections', 0) for r in results if r.get('success', False))
|
|
|
+
|
|
|
+ print(f"\n" + "="*60)
|
|
|
+ print(f"✅ Processing completed!")
|
|
|
+ print(f"📊 Statistics:")
|
|
|
+ print(f" Total files processed: {len(input_files)}")
|
|
|
+ print(f" Successful: {success_count}")
|
|
|
+ print(f" Failed: {error_count}")
|
|
|
+ if len(input_files) > 0:
|
|
|
+ print(f" Success rate: {success_count / len(input_files) * 100:.2f}%")
|
|
|
+ print(f" Total detections: {total_detections}")
|
|
|
+ if success_count > 0:
|
|
|
+ print(f" Avg detections per image: {total_detections / success_count:.2f}")
|
|
|
+ print(f"⏱️ Performance:")
|
|
|
+ print(f" Total time: {total_time:.2f} seconds")
|
|
|
+ if total_time > 0:
|
|
|
+ print(f" Throughput: {len(input_files) / total_time:.2f} images/second")
|
|
|
+ print(f" Avg time per image: {total_time / len(input_files):.2f} seconds")
|
|
|
+
|
|
|
+ # 保存结果统计
|
|
|
+ stats = {
|
|
|
+ "total_files": len(input_files),
|
|
|
+ "success_count": success_count,
|
|
|
+ "error_count": error_count,
|
|
|
+ "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
|
|
|
+ "total_detections": total_detections,
|
|
|
+ "avg_detections": total_detections / success_count if success_count > 0 else 0,
|
|
|
+ "total_time": total_time,
|
|
|
+ "throughput": len(input_files) / total_time if total_time > 0 else 0,
|
|
|
+ "avg_time_per_image": total_time / len(input_files) if len(input_files) > 0 else 0,
|
|
|
+ "model_path": model_path,
|
|
|
+ "device": args.device,
|
|
|
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+
|
|
|
+ # 保存最终结果
|
|
|
+ output_file_name = Path(args.output_dir).name
|
|
|
+ output_file = os.path.join(args.output_dir, f"{output_file_name}_results.json")
|
|
|
+ final_results = {
|
|
|
+ "stats": stats,
|
|
|
+ "results": results
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(output_file, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(final_results, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ print(f"💾 Results saved to: {output_file}")
|
|
|
+
|
|
|
+ # 收集处理结果
|
|
|
+ if args.collect_results:
|
|
|
+ output_csv = Path(args.collect_results).resolve()
|
|
|
+ else:
|
|
|
+ output_csv = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
|
|
|
+
|
|
|
+ collect_results(results, str(output_csv))
|
|
|
+ print(f"💾 Processed files saved to: {output_csv}")
|
|
|
+
|
|
|
+ return 0
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Processing failed: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+ return 1
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ print(f"🚀 Starting DocLayout-YOLO Batch Processing...")
|
|
|
+
|
|
|
+ if len(sys.argv) == 1:
|
|
|
+ # 默认配置
|
|
|
+ print("ℹ️ No command line arguments provided. Running with default configuration...")
|
|
|
+
|
|
|
+ default_config = {
|
|
|
+ # "input_file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.img/B用户_扫描流水_page_002.png",
|
|
|
+ "input_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/data_PPStructureV3_Results/B用户_扫描流水",
|
|
|
+ "output_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/doclayout_yolo_results",
|
|
|
+ "device": "cpu",
|
|
|
+ "draw_type_label": True,
|
|
|
+ "draw_score": True,
|
|
|
+ "draw_order_number": True,
|
|
|
+ }
|
|
|
+
|
|
|
+ sys.argv = [sys.argv[0]]
|
|
|
+ for key, value in default_config.items():
|
|
|
+ if isinstance(value, bool):
|
|
|
+ if value:
|
|
|
+ sys.argv.append(f"--{key}")
|
|
|
+ else:
|
|
|
+ sys.argv.extend([f"--{key}", str(value)])
|
|
|
+
|
|
|
+ sys.exit(main())
|