"""DocLayout-YOLO 批量布局检测工具""" import json import time import os import traceback import argparse import sys from pathlib import Path from typing import List, Dict, Any, Union from tqdm import tqdm from doclayout_yolo import YOLOv10 import numpy as np from PIL import Image, ImageDraw, ImageFont from mineru.utils.enum_class import ModelPath from mineru.utils.models_download_utils import auto_download_and_get_model_root_path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils import get_input_files class DocLayoutYOLOModel: """DocLayout-YOLO 模型封装类""" # 类别ID到名称的映射(DocLayout-YOLO标准类别) # 类型定义参见:docs/en/reference/output_files.md CATEGORY_NAMES = { 0: 'title', # Title 1: 'plain_text', # text 2: 'abandon', # Including headers, footers, page numbers, and page annotations 3: 'figure', # Image 4: 'figure_caption', # Image caption 5: 'table', # Table 6: 'table_caption', # Table caption 7: 'table_footnote', # Table footnote 8: 'isolate_formula', # Interline formula 9: 'formula_caption', # Interline formula number 13: 'embedding', # Inline formula 14: 'isolated', # Interline formula 15: 'text', # OCR recognition result } # 类别对应的颜色(与MinerU保持一致) CATEGORY_COLORS = { 0: (102, 102, 255), # title: 蓝色 1: (153, 0, 76), # plain_text: 深红 2: (158, 158, 158), # abandon: 灰色 3: (153, 255, 51), # figure: 绿色 4: (102, 178, 255), # figure_caption: 浅蓝 5: (204, 204, 0), # table: 黄色 6: (255, 255, 102), # table_caption: 浅黄 7: (229, 255, 204), # table_footnote: 浅绿 8: (0, 255, 0), # isolate_formula: 亮绿 9: (255, 0, 0), # formula_caption: 红色 } def __init__( self, weight: str, device: str = "cuda", imgsz: int = 1280, conf: float = 0.25, iou: float = 0.45, ): self.model = YOLOv10(weight).to(device) self.device = device self.imgsz = imgsz self.conf = conf self.iou = iou def _parse_prediction(self, prediction) -> List[Dict]: """解析模型预测结果""" layout_res = [] if not hasattr(prediction, "boxes") or prediction.boxes is None: return layout_res for xyxy, conf, cls in zip( 
            prediction.boxes.xyxy.cpu(),
            prediction.boxes.conf.cpu(),
            prediction.boxes.cls.cpu(),
        ):
            coords = list(map(int, xyxy.tolist()))
            xmin, ymin, xmax, ymax = coords
            layout_res.append({
                "category_id": int(cls.item()),
                # 4-point polygon, clockwise from the top-left corner
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            })
        return layout_res

    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
        """Run prediction on a single image."""
        prediction = self.model.predict(
            image,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False
        )[0]
        return self._parse_prediction(prediction)

    def batch_predict(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        batch_size: int = 4
    ) -> List[List[Dict]]:
        """Run prediction on a list of images in batches."""
        results = []
        with tqdm(total=len(images), desc="Layout Predict", disable=True) as pbar:
            for idx in range(0, len(images), batch_size):
                batch = images[idx: idx + batch_size]
                # With batch_size == 1 the confidence threshold is lowered to 90% of self.conf
                conf = 0.9 * self.conf if batch_size == 1 else self.conf
                predictions = self.model.predict(
                    batch,
                    imgsz=self.imgsz,
                    conf=conf,
                    iou=self.iou,
                    verbose=False,
                )
                for pred in predictions:
                    results.append(self._parse_prediction(pred))
                pbar.update(len(batch))
        return results

    def visualize(
        self,
        image: Union[np.ndarray, Image.Image],
        results: List[Dict],
        output_path: str = None,
        draw_type_label: bool = True,
        draw_score: bool = True,
        draw_order_number: bool = False,
        font_size: int = 14,
        line_width: int = 2,
        verbose: bool = False
    ) -> Image.Image:
        """Visualize layout detection results.

        Args:
            image: Input image (PIL Image or numpy array).
            results: List of detection results.
            output_path: Output image path (the image is not saved if None).
            draw_type_label: Whether to label each box with its type name (default True).
            draw_score: Whether to label each box with its confidence score (default True).
            draw_order_number: Whether to label each box with its detection order number (default False).
            font_size: Label font size (default 14).
            line_width: Bounding-box line width (default 2).
            verbose: Whether to print each box and the save path.

        Returns:
            PIL.Image: The annotated image.
        """
        # 1. Normalize the image. Colored boxes need an RGB canvas, so grayscale,
        #    palette, and RGBA inputs are converted; convert() also returns a copy,
        #    so the caller's image is never drawn on.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert("RGB")
        draw = ImageDraw.Draw(image)

        # 2. Try to load a TrueType font (macOS, then Linux), falling back to PIL's default font
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", font_size)
        except (OSError, ImportError):
            try:
                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
            except (OSError, ImportError):
                font = ImageFont.load_default()

        # 3. Draw each detection box
        for idx, res in enumerate(results, 1):
            poly = res['poly']
            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
            category_id = res['category_id']
            score = res['score']

            # Look up the category name and color
            category_name = self.CATEGORY_NAMES.get(category_id, f'unknown_{category_id}')
            color = self.CATEGORY_COLORS.get(category_id, (255, 0, 0))

            # 3.1 Draw the bounding box
            draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)

            # 3.2 Build the label text
            labels = []
            if draw_type_label:
                labels.append(category_name)
            if draw_score:
                labels.append(f"{score:.2f}")
            if draw_order_number:
                labels.append(f"#{idx}")
            label_text = " | ".join(labels)

            # 3.3 Draw the label text (if any)
            if label_text:
                # Compute the text bounding box
                bbox = draw.textbbox((xmin + 2, ymin + 2), label_text, font=font)
                # Draw an opaque background in the box color (an RGB canvas has no
                # alpha channel, so a semi-transparent fill is not possible here)
                draw.rectangle(bbox, fill=color)
                # Draw the text
                draw.text((xmin + 2, ymin + 2), label_text, fill='white', font=font)

            if verbose:
                print(f"Box #{idx}: {category_name} [{xmin}, {ymin}, {xmax}, {ymax}] score={score:.3f}")

        # 4. Save to file (if a path was given)
        if output_path:
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
            image.save(output_path)
            if verbose:
                print(f"✅ Layout visualization saved to: {output_path}")

        return image


def process_images_batch(
    image_paths: List[str],
    model: DocLayoutYOLOModel,
    output_dir: str = "./output",
    draw_type_label: bool = True,
    draw_score: bool = True,
    draw_order_number: bool = False,
    save_json: bool = True,
    save_visualization: bool = True,
    font_size: int = 14,
    line_width: int = 2
) -> List[Dict[str, Any]]:
    """Run layout detection on each image and save JSON and/or visualization outputs."""

    # Create the output directories
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if save_json:
        json_dir = output_path / "json"
        json_dir.mkdir(exist_ok=True)

    if save_visualization:
        viz_dir = output_path / "visualization"
        viz_dir.mkdir(exist_ok=True)

    all_results = []
    total_images = len(image_paths)

    print(f"Processing {total_images} images")

    # Show progress with tqdm
    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:

        for img_path in image_paths:
            start_time = time.time()

            try:
                # Load the image as RGB and close the file handle right away
                with Image.open(img_path) as img:
                    image = img.convert("RGB")

                # Predict
                results = model.predict(image)

                processing_time = time.time() - start_time

                # Build the output file name from the input file's stem
                input_path = Path(img_path)
                output_filename = input_path.stem

                # Save JSON
                json_output_path = None
                if save_json:
                    json_output_path = json_dir / f"{output_filename}_layout.json"
                    with open(json_output_path, 'w', encoding='utf-8') as f:
                        json.dump({
                            "image_path": str(img_path),
                            "image_size": list(image.size),
                            "layout_results": results,
                            "processing_time": processing_time
                        }, f, ensure_ascii=False, indent=2)

                # Save the visualization
                viz_output_path = None
                if save_visualization:
                    viz_output_path = viz_dir / f"{output_filename}_layout.png"
                    model.visualize(
                        image,
                        results,
                        output_path=str(viz_output_path),
                        draw_type_label=draw_type_label,
                        draw_score=draw_score,
                        draw_order_number=draw_order_number,
                        font_size=font_size,
                        line_width=line_width,
                        verbose=False
                    )

                # Record the result
                all_results.append({
                    "image_path": str(input_path),
                    "processing_time": processing_time,
                    "success": True,
                    "num_detections": len(results),
                    "output_json": str(json_output_path) if json_output_path else None,
                    "output_viz": str(viz_output_path) if viz_output_path else None,
                    "detections": results
                })

                # Update the progress bar
                success_count = sum(1 for r in all_results if r.get('success', False))
                pbar.update(1)
                pbar.set_postfix({
                    'time': f"{processing_time:.2f}s",
                    'boxes': len(results),
                    'success': f"{success_count}/{len(all_results)}"
                })

            except Exception as e:
                print(f"\n❌ Error processing {Path(img_path).name}: {e}", file=sys.stderr)
                traceback.print_exc()
                all_results.append({
                    "image_path": str(img_path),
                    "processing_time": 0,
                    "success": False,
                    "error": str(e)
                })
                pbar.update(1)

    return all_results


def collect_results(results: List[Dict], output_csv: str):
    """Write a per-image processing summary to a CSV file."""
    import csv

    with open(output_csv, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['image_path', 'status', 'num_detections', 'processing_time'])

        for result in results:
            writer.writerow([
                result['image_path'],
                'success' if result.get('success', False) else 'failed',
                result.get('num_detections', 0),
                f"{result.get('processing_time', 0):.2f}"
            ])


def main():
    """Entry point: parse arguments, load the model, and process all input images."""

    parser = argparse.ArgumentParser(description="DocLayout-YOLO Batch Layout Detection Tool")

    # Input arguments (exactly one is required)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_file", type=str, help="Input image file")
    input_group.add_argument("--input_dir", type=str, help="Input directory")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path column")

    # Output arguments
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--collect_results", type=str, help="Collect results to CSV file")

    # Model arguments
    parser.add_argument("--model_path", type=str,
                        help="Custom model path (optional)")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device")
    parser.add_argument("--imgsz", type=int, default=1280, help="Image size for inference")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--iou", type=float, default=0.45, help="IoU threshold")

    # Visualization arguments
    parser.add_argument("--no-visualization", action="store_true", help="Disable visualization output")
    parser.add_argument("--no-json", action="store_true", help="Disable JSON output")
    # BooleanOptionalAction (Python 3.9+) also generates --no-draw_type_label / --no-draw_score.
    # With action="store_true" and default=True, these labels could never be turned off.
    parser.add_argument("--draw_type_label", action=argparse.BooleanOptionalAction, default=True,
                        help="Draw type labels")
    parser.add_argument("--draw_score", action=argparse.BooleanOptionalAction, default=True,
                        help="Draw confidence scores")
    parser.add_argument("--draw_order_number", action="store_true", help="Draw detection order numbers")
    parser.add_argument("--font_size", type=int, default=14, help="Font size for labels")
    parser.add_argument("--line_width", type=int, default=2, help="Line width for bounding boxes")

    # Other arguments
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")

    args = parser.parse_args()

    try:
        # Collect the input files
        print("🔄 Getting input files...")
        input_files = get_input_files(args)

        if not input_files:
            print("❌ No input files found")
            return 1

        if args.test_mode:
            input_files = input_files[:20]
            print(f"🧪 Test mode: processing only {len(input_files)} images")

        # Load the model: use --model_path if given, otherwise resolve the default
        # weights through MinerU's auto_download_and_get_model_root_path
        print("🔄 Loading DocLayout-YOLO model...")
        if args.model_path:
            model_path = args.model_path
        else:
            model_path = os.path.join(
                auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
                ModelPath.doclayout_yolo
            )
        model = DocLayoutYOLOModel(
            weight=model_path,
            device=args.device,
            imgsz=args.imgsz,
            conf=args.conf,
            iou=args.iou
        )

        print(f"✅ Model loaded: {model_path}")
        print(f"🔧 Device: {args.device}")
        print(f"🔧 Image size: {args.imgsz}")
        print(f"🔧 Confidence threshold: {args.conf}")
        print(f"🔧 IoU threshold: {args.iou}")

        # Start processing
        start_time = time.time()
        results = process_images_batch(
            input_files,
            model,
            args.output_dir,
            draw_type_label=args.draw_type_label,
            draw_score=args.draw_score,
            draw_order_number=args.draw_order_number,
            save_json=not args.no_json,
            save_visualization=not args.no_visualization,
            font_size=args.font_size,
            line_width=args.line_width
        )
        total_time = time.time() - start_time

        # Summarize the results
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count
        total_detections = sum(r.get('num_detections', 0) for r in results if r.get('success', False))

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"  Total files processed: {len(input_files)}")
        print(f"  Successful: {success_count}")
        print(f"  Failed: {error_count}")
        if len(input_files) > 0:
            print(f"  Success rate: {success_count / len(input_files) * 100:.2f}%")
        print(f"  Total detections: {total_detections}")
        if success_count > 0:
            print(f"  Avg detections per image: {total_detections / success_count:.2f}")

        print("⏱️ Performance:")
        print(f"  Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f"  Throughput: {len(input_files) / total_time:.2f} images/second")
            print(f"  Avg time per image: {total_time / len(input_files):.2f} seconds")

        # Build the run statistics
        stats = {
            "total_files": len(input_files),
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
            "total_detections": total_detections,
            "avg_detections": total_detections / success_count if success_count > 0 else 0,
            "total_time": total_time,
            "throughput": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / len(input_files) if len(input_files) > 0 else 0,
            "model_path": model_path,
            "device": args.device,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results as <output_dir>/<output_dir name>_results.json
        output_file_name = Path(args.output_dir).name
        output_file = os.path.join(args.output_dir, f"{output_file_name}_results.json")
        final_results = {
            "stats": stats,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)

        print(f"💾 Results saved to: {output_file}")

        # Write the per-image processing summary CSV
        if args.collect_results:
            output_csv = Path(args.collect_results).resolve()
        else:
            output_csv = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        collect_results(results, str(output_csv))
        print(f"💾 Processed files saved to: {output_csv}")

        return 0

    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    print("🚀 Starting DocLayout-YOLO Batch Processing...")

    if len(sys.argv) == 1:
        # No command-line arguments: fall back to a built-in default configuration
        print("ℹ️ No command line arguments provided. Running with default configuration...")

        default_config = {
            # "input_file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.img/B用户_扫描流水_page_002.png",
            "input_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/data_PPStructureV3_Results/B用户_扫描流水",
            "output_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/doclayout_yolo_results",
            "device": "cpu",
            "draw_type_label": True,
            "draw_score": True,
            "draw_order_number": True,
        }

        # Rebuild argv from the config: True booleans become bare flags, False ones are
        # skipped, and every other value becomes "--key value"
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            if isinstance(value, bool):
                if value:
                    sys.argv.append(f"--{key}")
            else:
                sys.argv.extend([f"--{key}", str(value)])

    sys.exit(main())
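

# Example invocations (a sketch: the script name "layout_batch.py" and every path
# below are hypothetical placeholders; only flags defined in main() are used).
#
#   # Process a directory on CPU and write the summary CSV to a chosen location
#   python layout_batch.py --input_dir ./pages --output_dir ./layout_out \
#       --device cpu --collect_results ./layout_out/summary.csv
#
#   # Smoke test: first 20 files from a list, visualizations only, with order numbers
#   python layout_batch.py --input_file_list ./files.txt --output_dir ./smoke \
#       --test_mode --no-json --draw_order_number
#
# For each input image, a run writes <output_dir>/json/<stem>_layout.json and
# <output_dir>/visualization/<stem>_layout.png (unless disabled), plus one
# <output_dir>/<output_dir name>_results.json holding the run statistics.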