Переглянути джерело

feat: 新增DocLayout-YOLO批量布局检测工具,支持批量处理图像并输出结果

zhch158_admin 3 тижнів тому
батько
коміт
37032b5861
1 змінених файлів з 531 додано та 0 видалено
  1. 531 0
      zhch/model_evaluator/doclayoutyolo_batch.py

+ 531 - 0
zhch/model_evaluator/doclayoutyolo_batch.py

@@ -0,0 +1,531 @@
+"""DocLayout-YOLO 批量布局检测工具"""
+import json
+import time
+import os
+import traceback
+import argparse
+import sys
+from pathlib import Path
+from typing import List, Dict, Any, Union
+from tqdm import tqdm
+
+from doclayout_yolo import YOLOv10
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils import get_input_files
+
+
+class DocLayoutYOLOModel:
+    """DocLayout-YOLO 模型封装类"""
+    
+    # 类别ID到名称的映射(DocLayout-YOLO标准类别)
+    # 类型定义参见:docs/en/reference/output_files.md
+    CATEGORY_NAMES = {
+        0: 'title',            # Title 
+        1: 'plain_text',       # text
+        2: 'abandon',          # Including headers, footers, page numbers, and page annotations
+        3: 'figure',           # Image
+        4: 'figure_caption',   # Image caption
+        5: 'table',            # Table
+        6: 'table_caption',    # Table caption
+        7: 'table_footnote',   # Table footnote
+        8: 'isolate_formula',  # Interline formula
+        9: 'formula_caption',  # Interline formula number
+        13: 'embedding',       # Inline formula
+        14: 'isolated',        # Interline formula
+        15: 'text',            # OCR recognition result
+    }
+    
+    # 类别对应的颜色(与MinerU保持一致)
+    CATEGORY_COLORS = {
+        0: (102, 102, 255),      # title: 蓝色
+        1: (153, 0, 76),         # plain_text: 深红
+        2: (158, 158, 158),      # abandon: 灰色
+        3: (153, 255, 51),       # figure: 绿色
+        4: (102, 178, 255),      # figure_caption: 浅蓝
+        5: (204, 204, 0),        # table: 黄色
+        6: (255, 255, 102),      # table_caption: 浅黄
+        7: (229, 255, 204),      # table_footnote: 浅绿
+        8: (0, 255, 0),          # isolate_formula: 亮绿
+        9: (255, 0, 0),          # formula_caption: 红色
+    }
+
+    def __init__(
+        self,
+        weight: str,
+        device: str = "cuda",
+        imgsz: int = 1280,
+        conf: float = 0.25,
+        iou: float = 0.45,
+    ):
+        self.model = YOLOv10(weight).to(device)
+        self.device = device
+        self.imgsz = imgsz
+        self.conf = conf
+        self.iou = iou
+
+    def _parse_prediction(self, prediction) -> List[Dict]:
+        """解析模型预测结果"""
+        layout_res = []
+
+        if not hasattr(prediction, "boxes") or prediction.boxes is None:
+            return layout_res
+
+        for xyxy, conf, cls in zip(
+            prediction.boxes.xyxy.cpu(),
+            prediction.boxes.conf.cpu(),
+            prediction.boxes.cls.cpu(),
+        ):
+            coords = list(map(int, xyxy.tolist()))
+            xmin, ymin, xmax, ymax = coords
+            layout_res.append({
+                "category_id": int(cls.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 3),
+            })
+        return layout_res
+
+    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
+        """单张图片预测"""
+        prediction = self.model.predict(
+            image,
+            imgsz=self.imgsz,
+            conf=self.conf,
+            iou=self.iou,
+            verbose=False
+        )[0]
+        return self._parse_prediction(prediction)
+
+    def batch_predict(
+        self,
+        images: List[Union[np.ndarray, Image.Image]],
+        batch_size: int = 4
+    ) -> List[List[Dict]]:
+        """批量预测"""
+        results = []
+        with tqdm(total=len(images), desc="Layout Predict", disable=True) as pbar:
+            for idx in range(0, len(images), batch_size):
+                batch = images[idx: idx + batch_size]
+                conf = 0.9 * self.conf if batch_size == 1 else self.conf
+                
+                predictions = self.model.predict(
+                    batch,
+                    imgsz=self.imgsz,
+                    conf=conf,
+                    iou=self.iou,
+                    verbose=False,
+                )
+                for pred in predictions:
+                    results.append(self._parse_prediction(pred))
+                pbar.update(len(batch))
+        return results
+
+    def visualize(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        results: List[Dict],
+        output_path: str = None,
+        draw_type_label: bool = True,
+        draw_score: bool = True,
+        draw_order_number: bool = False,
+        font_size: int = 14,
+        line_width: int = 2,
+        verbose: bool = False
+    ) -> Image.Image:
+        """可视化布局检测结果"""
+        """
+        Args:
+            image: 输入图像(PIL Image或numpy array)
+            results: 检测结果列表
+            output_path: 输出图片路径(如果为None则不保存)
+            draw_type_label: 是否标注类型名称(默认True)
+            draw_score: 是否标注置信度分数(默认True)
+            draw_order_number: 是否标注检测顺序编号(默认False)
+            font_size: 字体大小(默认14)
+            line_width: 边框线宽(默认2)
+        
+        Returns:
+            PIL.Image: 标注后的图像
+        """
+        # 1. 转换图像格式
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        else:
+            image = image.copy()
+        
+        draw = ImageDraw.Draw(image)
+        
+        # 2. 尝试加载字体
+        try:
+            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", font_size)
+        except:
+            try:
+                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
+            except:
+                font = ImageFont.load_default()
+        
+        # 3. 绘制每个检测框
+        for idx, res in enumerate(results, 1):
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            category_id = res['category_id']
+            score = res['score']
+            
+            # 获取类别名称和颜色
+            category_name = self.CATEGORY_NAMES.get(category_id, f'unknown_{category_id}')
+            color = self.CATEGORY_COLORS.get(category_id, (255, 0, 0))
+            
+            # 3.1 绘制边框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)
+            
+            # 3.2 准备标注文本
+            labels = []
+            if draw_type_label:
+                labels.append(category_name)
+            if draw_score:
+                labels.append(f"{score:.2f}")
+            if draw_order_number:
+                labels.append(f"#{idx}")
+            
+            label_text = " | ".join(labels) if labels else ""
+            
+            # 3.3 绘制标注文本(如果有)
+            if label_text:
+                # 计算文本背景框
+                bbox = draw.textbbox((xmin + 2, ymin + 2), label_text, font=font)
+                # 绘制半透明背景
+                draw.rectangle(bbox, fill=(*color, 200))
+                # 绘制文本
+                draw.text((xmin + 2, ymin + 2), label_text, fill='white', font=font)
+            
+            if verbose:
+                print(f"Box #{idx}: {category_name} [{xmin}, {ymin}, {xmax}, {ymax}] score={score:.3f}")
+        
+        # 4. 保存到文件(如果指定了路径)
+        if output_path:
+            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
+            image.save(output_path)
+            if verbose:
+                print(f"✅ Layout visualization saved to: {output_path}")
+        
+        return image
+
+
+def process_images_batch(
+    image_paths: List[str],
+    model: DocLayoutYOLOModel,
+    output_dir: str = "./output",
+    draw_type_label: bool = True,
+    draw_score: bool = True,
+    draw_order_number: bool = False,
+    save_json: bool = True,
+    save_visualization: bool = True,
+    font_size: int = 14,
+    line_width: int = 2
+) -> List[Dict[str, Any]]:
+    """批量处理图像"""
+    
+    # 创建输出目录
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    if save_json:
+        json_dir = output_path / "json"
+        json_dir.mkdir(exist_ok=True)
+    
+    if save_visualization:
+        viz_dir = output_path / "visualization"
+        viz_dir.mkdir(exist_ok=True)
+    
+    all_results = []
+    total_images = len(image_paths)
+    
+    print(f"Processing {total_images} images")
+    
+    # 使用tqdm显示进度
+    with tqdm(total=total_images, desc="Processing images", unit="img",
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
+        for img_path in image_paths:
+            start_time = time.time()
+            
+            try:
+                # 加载图片
+                image = Image.open(img_path)
+                
+                # 预测
+                results = model.predict(image)
+                processing_time = time.time() - start_time
+                
+                # 生成输出文件名
+                input_path = Path(img_path)
+                output_filename = input_path.stem
+                
+                # 保存JSON
+                json_output_path = None
+                if save_json:
+                    json_output_path = json_dir / f"{output_filename}_layout.json"
+                    with open(json_output_path, 'w', encoding='utf-8') as f:
+                        json.dump({
+                            "image_path": str(img_path),
+                            "image_size": list(image.size),
+                            "layout_results": results,
+                            "processing_time": processing_time
+                        }, f, ensure_ascii=False, indent=2)
+                
+                # 保存可视化
+                viz_output_path = None
+                if save_visualization:
+                    viz_output_path = viz_dir / f"{output_filename}_layout.png"
+                    model.visualize(
+                        image,
+                        results,
+                        output_path=str(viz_output_path),
+                        draw_type_label=draw_type_label,
+                        draw_score=draw_score,
+                        draw_order_number=draw_order_number,
+                        font_size=font_size,
+                        line_width=line_width,
+                        verbose=False
+                    )
+                
+                # 记录结果
+                all_results.append({
+                    "image_path": str(input_path),
+                    "processing_time": processing_time,
+                    "success": True,
+                    "num_detections": len(results),
+                    "output_json": str(json_output_path) if json_output_path else None,
+                    "output_viz": str(viz_output_path) if viz_output_path else None,
+                    "detections": results
+                })
+                
+                # 更新进度条
+                success_count = sum(1 for r in all_results if r.get('success', False))
+                pbar.update(1)
+                pbar.set_postfix({
+                    'time': f"{processing_time:.2f}s",
+                    'boxes': len(results),
+                    'success': f"{success_count}/{len(all_results)}"
+                })
+                
+            except Exception as e:
+                print(f"\n❌ Error processing {Path(img_path).name}: {e}", file=sys.stderr)
+                traceback.print_exc()
+                
+                all_results.append({
+                    "image_path": str(img_path),
+                    "processing_time": 0,
+                    "success": False,
+                    "error": str(e)
+                })
+                pbar.update(1)
+    
+    return all_results
+
+
+def collect_results(results: List[Dict], output_csv: str):
+    """收集处理结果到CSV"""
+    import csv
+    
+    with open(output_csv, 'w', encoding='utf-8', newline='') as f:
+        writer = csv.writer(f)
+        writer.writerow(['image_path', 'status', 'num_detections', 'processing_time'])
+        
+        for result in results:
+            writer.writerow([
+                result['image_path'],
+                'success' if result.get('success', False) else 'failed',
+                result.get('num_detections', 0),
+                f"{result.get('processing_time', 0):.2f}"
+            ])
+
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="DocLayout-YOLO Batch Layout Detection Tool")
+    
+    # 输入参数
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_file", type=str, help="Input image file")
+    input_group.add_argument("--input_dir", type=str, help="Input directory")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path column")
+    
+    # 输出参数
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    parser.add_argument("--collect_results", type=str, help="Collect results to CSV file")
+    
+    # 模型参数
+    parser.add_argument("--model_path", type=str, help="Custom model path (optional)")
+    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device")
+    parser.add_argument("--imgsz", type=int, default=1280, help="Image size for inference")
+    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
+    parser.add_argument("--iou", type=float, default=0.45, help="IoU threshold")
+    
+    # 可视化参数
+    parser.add_argument("--no-visualization", action="store_true", help="Disable visualization output")
+    parser.add_argument("--no-json", action="store_true", help="Disable JSON output")
+    parser.add_argument("--draw_type_label", action="store_true", default=True, help="Draw type labels")
+    parser.add_argument("--draw_score", action="store_true", default=True, help="Draw confidence scores")
+    parser.add_argument("--draw_order_number", action="store_true", help="Draw detection order numbers")
+    parser.add_argument("--font_size", type=int, default=14, help="Font size for labels")
+    parser.add_argument("--line_width", type=int, default=2, help="Line width for bounding boxes")
+    
+    # 其他参数
+    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
+    
+    args = parser.parse_args()
+    
+    try:
+        # 获取输入文件
+        print("🔄 Getting input files...")
+        input_files = get_input_files(args)
+        
+        if not input_files:
+            print("❌ No input files found")
+            return 1
+        
+        if args.test_mode:
+            input_files = input_files[:20]
+            print(f"🧪 Test mode: processing only {len(input_files)} images")
+        
+        # 加载模型
+        print("🔄 Loading DocLayout-YOLO model...")
+        if args.model_path:
+            model_path = args.model_path
+        else:
+            model_path = os.path.join(
+                auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
+                ModelPath.doclayout_yolo
+            )
+        
+        model = DocLayoutYOLOModel(
+            weight=model_path,
+            device=args.device,
+            imgsz=args.imgsz,
+            conf=args.conf,
+            iou=args.iou
+        )
+        
+        print(f"✅ Model loaded: {model_path}")
+        print(f"🔧 Device: {args.device}")
+        print(f"🔧 Image size: {args.imgsz}")
+        print(f"🔧 Confidence threshold: {args.conf}")
+        print(f"🔧 IoU threshold: {args.iou}")
+        
+        # 开始处理
+        start_time = time.time()
+        results = process_images_batch(
+            input_files,
+            model,
+            args.output_dir,
+            draw_type_label=args.draw_type_label,
+            draw_score=args.draw_score,
+            draw_order_number=args.draw_order_number,
+            save_json=not args.no_json,
+            save_visualization=not args.no_visualization,
+            font_size=args.font_size,
+            line_width=args.line_width
+        )
+        total_time = time.time() - start_time
+        
+        # 统计结果
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
+        total_detections = sum(r.get('num_detections', 0) for r in results if r.get('success', False))
+        
+        print(f"\n" + "="*60)
+        print(f"✅ Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files processed: {len(input_files)}")
+        print(f"  Successful: {success_count}")
+        print(f"  Failed: {error_count}")
+        if len(input_files) > 0:
+            print(f"  Success rate: {success_count / len(input_files) * 100:.2f}%")
+        print(f"  Total detections: {total_detections}")
+        if success_count > 0:
+            print(f"  Avg detections per image: {total_detections / success_count:.2f}")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
+        if total_time > 0:
+            print(f"  Throughput: {len(input_files) / total_time:.2f} images/second")
+            print(f"  Avg time per image: {total_time / len(input_files):.2f} seconds")
+        
+        # 保存结果统计
+        stats = {
+            "total_files": len(input_files),
+            "success_count": success_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
+            "total_detections": total_detections,
+            "avg_detections": total_detections / success_count if success_count > 0 else 0,
+            "total_time": total_time,
+            "throughput": len(input_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_image": total_time / len(input_files) if len(input_files) > 0 else 0,
+            "model_path": model_path,
+            "device": args.device,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+        }
+        
+        # 保存最终结果
+        output_file_name = Path(args.output_dir).name
+        output_file = os.path.join(args.output_dir, f"{output_file_name}_results.json")
+        final_results = {
+            "stats": stats,
+            "results": results
+        }
+        
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
+        
+        print(f"💾 Results saved to: {output_file}")
+        
+        # 收集处理结果
+        if args.collect_results:
+            output_csv = Path(args.collect_results).resolve()
+        else:
+            output_csv = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
+        
+        collect_results(results, str(output_csv))
+        print(f"💾 Processed files saved to: {output_csv}")
+        
+        return 0
+        
+    except Exception as e:
+        print(f"❌ Processing failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    print(f"🚀 Starting DocLayout-YOLO Batch Processing...")
+    
+    if len(sys.argv) == 1:
+        # 默认配置
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
+        
+        default_config = {
+            # "input_file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.img/B用户_扫描流水_page_002.png",
+            "input_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/data_PPStructureV3_Results/B用户_扫描流水",
+            "output_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/doclayout_yolo_results",
+            "device": "cpu",
+            "draw_type_label": True,
+            "draw_score": True,
+            "draw_order_number": True,
+        }
+        
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            if isinstance(value, bool):
+                if value:
+                    sys.argv.append(f"--{key}")
+            else:
+                sys.argv.extend([f"--{key}", str(value)])
+    
+    sys.exit(main())