| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531 |
- """DocLayout-YOLO 批量布局检测工具"""
- import json
- import time
- import os
- import traceback
- import argparse
- import sys
- from pathlib import Path
- from typing import List, Dict, Any, Union
- from tqdm import tqdm
- from doclayout_yolo import YOLOv10
- import numpy as np
- from PIL import Image, ImageDraw, ImageFont
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- from utils import get_input_files
class DocLayoutYOLOModel:
    """Wrapper around a DocLayout-YOLO (``YOLOv10``) layout-detection model.

    Provides single-image and batched prediction that return layout dicts of
    the form ``{"category_id", "poly", "score"}``, plus a helper that draws
    those detections onto an image.
    """

    # Category ID -> name mapping (DocLayout-YOLO standard categories).
    # For type definitions see: docs/en/reference/output_files.md
    CATEGORY_NAMES = {
        0: 'title',            # Title
        1: 'plain_text',       # text
        2: 'abandon',          # Including headers, footers, page numbers, and page annotations
        3: 'figure',           # Image
        4: 'figure_caption',   # Image caption
        5: 'table',            # Table
        6: 'table_caption',    # Table caption
        7: 'table_footnote',   # Table footnote
        8: 'isolate_formula',  # Interline formula
        9: 'formula_caption',  # Interline formula number
        13: 'embedding',       # Inline formula
        14: 'isolated',        # Interline formula
        15: 'text',            # OCR recognition result
    }

    # Box colour per category (kept consistent with MinerU). Categories not
    # listed here are drawn in red by `visualize`.
    CATEGORY_COLORS = {
        0: (102, 102, 255),   # title: blue
        1: (153, 0, 76),      # plain_text: dark red
        2: (158, 158, 158),   # abandon: gray
        3: (153, 255, 51),    # figure: green
        4: (102, 178, 255),   # figure_caption: light blue
        5: (204, 204, 0),     # table: yellow
        6: (255, 255, 102),   # table_caption: light yellow
        7: (229, 255, 204),   # table_footnote: light green
        8: (0, 255, 0),       # isolate_formula: bright green
        9: (255, 0, 0),       # formula_caption: red
    }

    # TrueType fonts tried in order for box labels (macOS, then a common Linux
    # path); Pillow's built-in default font is used if none can be loaded.
    _FONT_CANDIDATES = (
        "/System/Library/Fonts/Helvetica.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    )

    def __init__(
        self,
        weight: str,
        device: str = "cuda",
        imgsz: int = 1280,
        conf: float = 0.25,
        iou: float = 0.45,
    ):
        """Load the model weights and store the inference settings.

        Args:
            weight: Path to the DocLayout-YOLO weight file.
            device: Device the model is moved to (e.g. "cuda" or "cpu").
            imgsz: Image size passed to every ``predict`` call.
            conf: Confidence threshold passed to ``predict``.
            iou: IoU threshold passed to ``predict``.
        """
        self.model = YOLOv10(weight).to(device)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def _parse_prediction(self, prediction) -> List[Dict]:
        """Convert one model prediction into a list of layout dicts.

        Each dict has:
            - ``category_id``: integer class id.
            - ``poly``: the axis-aligned box as 8 ints, corners ordered
              top-left, top-right, bottom-right, bottom-left.
            - ``score``: confidence rounded to 3 decimals.

        Returns an empty list if the prediction has no ``boxes``.
        """
        layout_res = []
        if not hasattr(prediction, "boxes") or prediction.boxes is None:
            return layout_res
        for xyxy, conf, cls in zip(
            prediction.boxes.xyxy.cpu(),
            prediction.boxes.conf.cpu(),
            prediction.boxes.cls.cpu(),
        ):
            # Coordinates are truncated to ints.
            xmin, ymin, xmax, ymax = map(int, xyxy.tolist())
            layout_res.append({
                "category_id": int(cls.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            })
        return layout_res

    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
        """Run layout detection on a single image.

        Returns:
            List of layout dicts (see ``_parse_prediction``).
        """
        prediction = self.model.predict(
            image,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False
        )[0]
        return self._parse_prediction(prediction)

    def batch_predict(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        batch_size: int = 4
    ) -> List[List[Dict]]:
        """Run layout detection on ``images`` in chunks of ``batch_size``.

        Args:
            images: Images to process.
            batch_size: Number of images passed to the model per call; must
                be at least 1.

        Returns:
            One list of layout dicts (see ``_parse_prediction``) per returned
            prediction, appended in batch order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            # A negative step would make range() yield nothing and silently
            # return no results; a zero step would fail with a vaguer error.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results = []
        # NOTE(review): the progress bar is permanently disabled (disable=True).
        with tqdm(total=len(images), desc="Layout Predict", disable=True) as pbar:
            for idx in range(0, len(images), batch_size):
                batch = images[idx: idx + batch_size]
                # NOTE(review): with batch_size == 1 the threshold is lowered
                # by 10%, unlike `predict`; the reason is not documented here.
                conf = 0.9 * self.conf if batch_size == 1 else self.conf

                predictions = self.model.predict(
                    batch,
                    imgsz=self.imgsz,
                    conf=conf,
                    iou=self.iou,
                    verbose=False,
                )
                for pred in predictions:
                    results.append(self._parse_prediction(pred))
                pbar.update(len(batch))
        return results

    def _load_label_font(self, font_size: int):
        """Return the first loadable font from ``_FONT_CANDIDATES``, else Pillow's default."""
        for font_path in self._FONT_CANDIDATES:
            try:
                return ImageFont.truetype(font_path, font_size)
            except (OSError, ImportError):
                # OSError: font file missing/unreadable.
                # ImportError: Pillow built without FreeType support.
                continue
        return ImageFont.load_default()

    def visualize(
        self,
        image: Union[np.ndarray, Image.Image],
        results: List[Dict],
        output_path: Union[str, None] = None,
        draw_type_label: bool = True,
        draw_score: bool = True,
        draw_order_number: bool = False,
        font_size: int = 14,
        line_width: int = 2,
        verbose: bool = False
    ) -> Image.Image:
        """Draw layout detection results onto a copy of ``image``.

        Args:
            image: Input image (PIL Image or numpy array); not modified.
            results: Layout dicts as returned by ``predict``.
            output_path: File to save the annotated image to; nothing is saved
                if None. Missing parent directories are created.
            draw_type_label: Label each box with its category name (default True).
            draw_score: Label each box with its confidence score (default True).
            draw_order_number: Label each box with its 1-based position in
                ``results`` (default False).
            font_size: Label font size (default 14).
            line_width: Box outline width (default 2).
            verbose: Print every box and the save path.

        Returns:
            PIL.Image.Image: The annotated image. Images that are not already
            RGB/RGBA are converted to RGB.
        """
        # 1. Obtain a private copy in a mode that accepts RGB colour tuples.
        #    Pillow raises when drawing RGB tuples on e.g. "L" (grayscale) or
        #    "1" images, and "P" images can run out of palette entries.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode in ("RGB", "RGBA"):
            image = image.copy()
        else:
            image = image.convert("RGB")

        draw = ImageDraw.Draw(image)

        # 2. Load a label font.
        font = self._load_label_font(font_size)

        # 3. Draw each detection box.
        for idx, res in enumerate(results, 1):
            poly = res['poly']
            # poly corners are TL, TR, BR, BL -> TL = poly[0:2], BR = poly[4:6]
            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
            category_id = res['category_id']
            score = res['score']

            category_name = self.CATEGORY_NAMES.get(category_id, f'unknown_{category_id}')
            color = self.CATEGORY_COLORS.get(category_id, (255, 0, 0))

            # 3.1 Box outline.
            draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)

            # 3.2 Build the label text.
            labels = []
            if draw_type_label:
                labels.append(category_name)
            if draw_score:
                labels.append(f"{score:.2f}")
            if draw_order_number:
                labels.append(f"#{idx}")
            label_text = " | ".join(labels)

            # 3.3 Draw the label (if any) in the box's top-left corner.
            if label_text:
                bbox = draw.textbbox((xmin + 2, ymin + 2), label_text, font=font)
                # Label background in the box colour. The alpha component only
                # has an effect on RGBA images; on RGB images the fill is opaque.
                draw.rectangle(bbox, fill=(*color, 200))
                draw.text((xmin + 2, ymin + 2), label_text, fill='white', font=font)

            if verbose:
                print(f"Box #{idx}: {category_name} [{xmin}, {ymin}, {xmax}, {ymax}] score={score:.3f}")

        # 4. Save to disk if a path was given.
        if output_path:
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
            image.save(output_path)
            if verbose:
                print(f"✅ Layout visualization saved to: {output_path}")

        return image
def process_images_batch(
    image_paths: List[str],
    model: DocLayoutYOLOModel,
    output_dir: str = "./output",
    draw_type_label: bool = True,
    draw_score: bool = True,
    draw_order_number: bool = False,
    save_json: bool = True,
    save_visualization: bool = True,
    font_size: int = 14,
    line_width: int = 2
) -> List[Dict[str, Any]]:
    """Run layout detection on each image and write per-image outputs.

    For every path, the image is opened, passed to ``model.predict`` and,
    when enabled, written out as ``<output_dir>/json/<stem>_layout.json`` and
    ``<output_dir>/visualization/<stem>_layout.png``. A failure on one image
    is printed to stderr (with traceback) and recorded, and processing
    continues with the next image.

    Args:
        image_paths: Image files to process.
        model: Loaded layout model.
        output_dir: Root output directory (created if missing).
        draw_type_label, draw_score, draw_order_number, font_size, line_width:
            Forwarded to ``model.visualize``.
        save_json: Write a JSON file per image.
        save_visualization: Write an annotated PNG per image.

    Returns:
        One record per input path, in input order. Successful records hold
        ``image_path``, ``processing_time``, ``success`` (True),
        ``num_detections``, ``output_json``, ``output_viz`` and ``detections``.
        Failed records hold ``image_path``, ``processing_time`` (0),
        ``success`` (False) and ``error``.

    Note:
        Output files are named after the input file's stem only, so inputs
        sharing a stem (e.g. same file name in different folders) overwrite
        each other's outputs.
    """
    # Create the output directory tree.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    json_dir = output_path / "json"
    viz_dir = output_path / "visualization"
    if save_json:
        json_dir.mkdir(exist_ok=True)
    if save_visualization:
        viz_dir.mkdir(exist_ok=True)

    all_results = []
    # Running count so the progress postfix does not rescan all_results.
    success_count = 0
    total_images = len(image_paths)

    print(f"Processing {total_images} images")

    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:

        for img_path in image_paths:
            start_time = time.time()

            try:
                input_path = Path(img_path)
                output_filename = input_path.stem

                # Image.open() is lazy and keeps the file handle open until the
                # image is closed; the context manager releases it so that long
                # batches do not accumulate open file descriptors.
                with Image.open(img_path) as image:
                    results = model.predict(image)
                    processing_time = time.time() - start_time

                    # Per-image JSON.
                    json_output_path = None
                    if save_json:
                        json_output_path = json_dir / f"{output_filename}_layout.json"
                        with open(json_output_path, 'w', encoding='utf-8') as f:
                            json.dump({
                                "image_path": str(img_path),
                                "image_size": list(image.size),
                                "layout_results": results,
                                "processing_time": processing_time
                            }, f, ensure_ascii=False, indent=2)

                    # Per-image visualization.
                    viz_output_path = None
                    if save_visualization:
                        viz_output_path = viz_dir / f"{output_filename}_layout.png"
                        model.visualize(
                            image,
                            results,
                            output_path=str(viz_output_path),
                            draw_type_label=draw_type_label,
                            draw_score=draw_score,
                            draw_order_number=draw_order_number,
                            font_size=font_size,
                            line_width=line_width,
                            verbose=False
                        )

                all_results.append({
                    "image_path": str(input_path),
                    "processing_time": processing_time,
                    "success": True,
                    "num_detections": len(results),
                    "output_json": str(json_output_path) if json_output_path else None,
                    "output_viz": str(viz_output_path) if viz_output_path else None,
                    "detections": results
                })
                success_count += 1

                pbar.update(1)
                pbar.set_postfix({
                    'time': f"{processing_time:.2f}s",
                    'boxes': len(results),
                    'success': f"{success_count}/{len(all_results)}"
                })

            except Exception as e:
                # Per-image boundary: report and keep processing the batch.
                print(f"\n❌ Error processing {Path(img_path).name}: {e}", file=sys.stderr)
                traceback.print_exc()

                all_results.append({
                    "image_path": str(img_path),
                    "processing_time": 0,
                    "success": False,
                    "error": str(e)
                })
                pbar.update(1)

    return all_results
def collect_results(results: List[Dict], output_csv: str):
    """Write a one-row-per-image processing summary to a CSV file.

    Columns: ``image_path``, ``status`` ("success" or "failed"),
    ``num_detections`` (0 when absent) and ``processing_time`` (seconds,
    formatted with two decimals, 0 when absent).

    Args:
        results: Per-image records; each must contain an ``image_path`` key.
        output_csv: Destination file; overwritten if it already exists.
    """
    import csv

    header = ['image_path', 'status', 'num_detections', 'processing_time']
    rows = (
        [
            record['image_path'],
            'success' if record.get('success', False) else 'failed',
            record.get('num_detections', 0),
            f"{record.get('processing_time', 0):.2f}",
        ]
        for record in results
    )

    with open(output_csv, 'w', encoding='utf-8', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)
        csv_writer.writerows(rows)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the batch layout detection tool."""
    parser = argparse.ArgumentParser(description="DocLayout-YOLO Batch Layout Detection Tool")

    # Input (exactly one source is required)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_file", type=str, help="Input image file")
    input_group.add_argument("--input_dir", type=str, help="Input directory")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path column")

    # Output
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--collect_results", type=str, help="Collect results to CSV file")

    # Model
    parser.add_argument("--model_path", type=str, help="Custom model path (optional)")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device")
    parser.add_argument("--imgsz", type=int, default=1280, help="Image size for inference")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--iou", type=float, default=0.45, help="IoU threshold")

    # Visualization
    parser.add_argument("--no-visualization", action="store_true", help="Disable visualization output")
    parser.add_argument("--no-json", action="store_true", help="Disable JSON output")
    # --draw_type_label / --draw_score default to True, so as store_true flags
    # they can never turn the option off. They are kept for backward
    # compatibility; the matching --no_* switches share the same dest and are
    # the way to disable them.
    parser.add_argument("--draw_type_label", action="store_true", default=True, help="Draw type labels")
    parser.add_argument("--no_draw_type_label", dest="draw_type_label", action="store_false",
                        help="Do not draw type labels")
    parser.add_argument("--draw_score", action="store_true", default=True, help="Draw confidence scores")
    parser.add_argument("--no_draw_score", dest="draw_score", action="store_false",
                        help="Do not draw confidence scores")
    parser.add_argument("--draw_order_number", action="store_true", help="Draw detection order numbers")
    parser.add_argument("--font_size", type=int, default=14, help="Font size for labels")
    parser.add_argument("--line_width", type=int, default=2, help="Line width for bounding boxes")

    # Misc
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")

    return parser


def main():
    """Command-line entry point: detect document layouts for a set of images.

    Parses arguments, collects the input files, loads the model (downloading
    the default weights when ``--model_path`` is not given), processes every
    image, prints a summary, and writes
    ``<output_dir>/<output_dir name>_results.json`` plus a per-image CSV.

    Returns:
        int: Exit code — 0 on success, 1 if no input files were found or any
        unexpected error occurred.
    """
    args = _build_arg_parser().parse_args()

    try:
        # Collect input files.
        print("🔄 Getting input files...")
        input_files = get_input_files(args)

        if not input_files:
            print("❌ No input files found")
            return 1

        if args.test_mode:
            input_files = input_files[:20]
            print(f"🧪 Test mode: processing only {len(input_files)} images")

        # Load the model.
        print("🔄 Loading DocLayout-YOLO model...")
        if args.model_path:
            model_path = args.model_path
        else:
            model_path = os.path.join(
                auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
                ModelPath.doclayout_yolo
            )

        model = DocLayoutYOLOModel(
            weight=model_path,
            device=args.device,
            imgsz=args.imgsz,
            conf=args.conf,
            iou=args.iou
        )

        print(f"✅ Model loaded: {model_path}")
        print(f"🔧 Device: {args.device}")
        print(f"🔧 Image size: {args.imgsz}")
        print(f"🔧 Confidence threshold: {args.conf}")
        print(f"🔧 IoU threshold: {args.iou}")

        # Process all images.
        start_time = time.time()
        results = process_images_batch(
            input_files,
            model,
            args.output_dir,
            draw_type_label=args.draw_type_label,
            draw_score=args.draw_score,
            draw_order_number=args.draw_order_number,
            save_json=not args.no_json,
            save_visualization=not args.no_visualization,
            font_size=args.font_size,
            line_width=args.line_width
        )
        total_time = time.time() - start_time

        # Summary statistics (input_files is non-empty here).
        num_files = len(input_files)
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count
        total_detections = sum(r.get('num_detections', 0) for r in results if r.get('success', False))

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f" Total files processed: {num_files}")
        print(f" Successful: {success_count}")
        print(f" Failed: {error_count}")
        print(f" Success rate: {success_count / num_files * 100:.2f}%")
        print(f" Total detections: {total_detections}")
        if success_count > 0:
            print(f" Avg detections per image: {total_detections / success_count:.2f}")
        print("⏱️ Performance:")
        print(f" Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f" Throughput: {num_files / total_time:.2f} images/second")
            print(f" Avg time per image: {total_time / num_files:.2f} seconds")

        stats = {
            "total_files": num_files,
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / num_files,
            "total_detections": total_detections,
            "avg_detections": total_detections / success_count if success_count > 0 else 0,
            "total_time": total_time,
            "throughput": num_files / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / num_files,
            "model_path": model_path,
            "device": args.device,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the combined results. Use the absolute path's basename so that
        # an output_dir of "." still yields a meaningful file name instead
        # of "_results.json".
        output_file_name = Path(os.path.abspath(args.output_dir)).name
        output_file = os.path.join(args.output_dir, f"{output_file_name}_results.json")
        final_results = {
            "stats": stats,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)

        print(f"💾 Results saved to: {output_file}")

        # Per-image CSV summary.
        if args.collect_results:
            output_csv = Path(args.collect_results).resolve()
            # A user-supplied CSV path may point into a directory that does
            # not exist yet.
            output_csv.parent.mkdir(parents=True, exist_ok=True)
        else:
            output_csv = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"

        collect_results(results, str(output_csv))
        print(f"💾 Processed files saved to: {output_csv}")

        return 0

    except Exception as e:
        # Top-level boundary: report with traceback and signal failure.
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    print("🚀 Starting DocLayout-YOLO Batch Processing...")

    if len(sys.argv) == 1:
        # No CLI arguments: run with a built-in configuration instead.
        # NOTE(review): the paths below are specific to one developer machine;
        # other users should pass arguments explicitly.
        print("ℹ️ No command line arguments provided. Running with default configuration...")

        default_config = {
            # "input_file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.img/B用户_扫描流水_page_002.png",
            "input_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/data_PPStructureV3_Results/B用户_扫描流水",
            "output_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/doclayout_yolo_results",
            "device": "cpu",
            "draw_type_label": True,
            "draw_score": True,
            "draw_order_number": True,
        }

        # Translate the config into argv: True booleans become bare flags,
        # False booleans are omitted, everything else becomes "--key value".
        synthetic_argv = [sys.argv[0]]
        for option, setting in default_config.items():
            if not isinstance(setting, bool):
                synthetic_argv += [f"--{option}", str(setting)]
            elif setting:
                synthetic_argv.append(f"--{option}")
        sys.argv = synthetic_argv

    sys.exit(main())
|