
feat: Add PP-StructureV3 multi-GPU multi-process processing, supporting batch image processing and result saving

zhch158_admin 3 months ago
parent
commit
56b8c69ce8

+ 518 - 0
zhch/ppstructurev3_multi_gpu_multiprocess_official.py

@@ -0,0 +1,518 @@
+# zhch/ppstructurev3_multi_gpu_multiprocess_official.py
+import json
+import time
+import traceback
+import argparse
+import sys
+from pathlib import Path
+from typing import List, Dict, Any
+from multiprocessing import Manager, Process, Queue
+from queue import Empty
+from paddlex import create_pipeline
+from paddlex.utils.device import constr_device, parse_device
+from tqdm import tqdm
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
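+# load_dotenv above reads a local .env file before any PaddleX initialization;
+# which variables matter (if any) depends on the deployment environment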
+
+def worker(pipeline_name_or_config_path: str, 
+          device: str, 
+          task_queue: Queue, 
+          result_queue: Queue,
+          batch_size: int, 
+          output_dir: str,
+          worker_id: int):
+    """
+    工作进程函数 - 基于官方parallel_inference.md实现
+    
+    Args:
+        pipeline_name_or_config_path: Pipeline名称或配置路径
+        device: 设备字符串
+        task_queue: 任务队列
+        result_queue: 结果队列
+        batch_size: 批处理大小
+        output_dir: 输出目录
+        worker_id: 工作进程ID
+    """
+    try:
+        # Create the pipeline instance on the assigned device
+        pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
+        print(f"Worker {worker_id} initialized with device {device}")
+        
+        should_end = False
+        batch = []
+        processed_count = 0
+        
+        while not should_end:
+            try:
+                input_path = task_queue.get_nowait()
+            except Empty:
+                should_end = True
+            else:
+                batch.append(input_path)
+            
+            if batch and (len(batch) == batch_size or should_end):
+                try:
+                    start_time = time.time()
+                    
+                    # Run batched prediction with the pipeline
+                    results = list(pipeline.predict(
+                        batch,
+                        use_doc_orientation_classify=True,
+                        use_doc_unwarping=False,
+                        use_seal_recognition=True,
+                        use_chart_recognition=True,
+                        use_table_recognition=True,
+                        use_formula_recognition=True,
+                    ))
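+                    # The use_* switches above toggle PP-StructureV3
+                    # sub-modules (orientation classification, seal, chart,
+                    # table and formula recognition); unwarping stays off here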
+                    
+                    batch_processing_time = time.time() - start_time
+                    batch_results = []
+                    
+                    for result in results:
+                        try:
+                            result_path = Path(result.input_path)
+                            
+                            # Derive the output stem; multi-page inputs
+                            # (e.g. PDFs) carry a page_index
+                            if result.get("page_index") is not None:
+                                output_filename = f"{result_path.stem}_{result['page_index']}"
+                            else:
+                                output_filename = f"{result_path.stem}"
+                            
+                            # Save JSON and Markdown outputs
+                            json_output_path = str(Path(output_dir, f"{output_filename}.json"))
+                            md_output_path = str(Path(output_dir, f"{output_filename}.md"))
+                            
+                            result.save_to_json(json_output_path)
+                            result.save_to_markdown(md_output_path)
+                            
+                            # Record the per-image result; the batch time is
+                            # split evenly across the batch as an approximation
+                            batch_results.append({
+                                "image_path": result_path.name,
+                                "processing_time": batch_processing_time / len(batch),
+                                "success": True,
+                                "device": device,
+                                "worker_id": worker_id,
+                                "output_json": json_output_path,
+                                "output_md": md_output_path
+                            })
+                            
+                            processed_count += 1
+                            
+                        except Exception as e:
+                            batch_results.append({
+                                "image_path": Path(result.input_path).name if hasattr(result, 'input_path') else "unknown",
+                                "processing_time": 0,
+                                "success": False,
+                                "device": device,
+                                "worker_id": worker_id,
+                                "error": str(e)
+                            })
+                    
+                    # Push this batch's records to the result queue
+                    result_queue.put(batch_results)
+                    
+                    print(f"Worker {worker_id} ({device}) processed batch of {len(batch)} files. Total: {processed_count}")
+                    
+                except Exception as e:
+                    # The whole batch failed; emit an error record per input
+                    error_results = []
+                    for img_path in batch:
+                        error_results.append({
+                            "image_path": Path(img_path).name,
+                            "processing_time": 0,
+                            "success": False,
+                            "device": device,
+                            "worker_id": worker_id,
+                            "error": str(e)
+                        })
+                    result_queue.put(error_results)
+                    
+                    print(f"Error processing batch {batch} on {device}: {e}", file=sys.stderr)
+                
+                batch.clear()
+    
+    except Exception as e:
+        print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+    finally:
+        print(f"Worker {worker_id} ({device}) finished")
+
+def parallel_process_with_official_approach(image_paths: List[str],
+                                          pipeline_name: str = "PP-StructureV3",
+                                          device_str: str = "gpu:0,1",
+                                          instances_per_device: int = 1,
+                                          batch_size: int = 1,
+                                          output_dir: str = "./output") -> List[Dict[str, Any]]:
+    """
+    使用官方推荐的方法进行多GPU多进程并行处理
+    
+    Args:
+        image_paths: 图像路径列表
+        pipeline_name: Pipeline名称
+        device_str: 设备字符串,如"gpu:0,1,2,3"
+        instances_per_device: 每个设备的实例数
+        batch_size: 批处理大小
+        output_dir: 输出目录
+        
+    Returns:
+        处理结果列表
+    """
+    # Create the output directory
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    # Parse the device string
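+    # Assumed behavior of the paddlex.utils.device helpers, inferred from
+    # their usage here: parse_device("gpu:0,1") -> ("gpu", [0, 1]) and
+    # constr_device("gpu", [0]) -> "gpu:0"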
+    try:
+        device_type, device_ids = parse_device(device_str)
+        if device_ids is None or len(device_ids) < 1:
+            print("No valid devices specified.", file=sys.stderr)
+            return []
+        
+        print(f"Parsed devices: {device_type}:{device_ids}")
+        
+    except Exception as e:
+        print(f"Failed to parse device string '{device_str}': {e}", file=sys.stderr)
+        return []
+    
+    # Validate the batch size
+    if batch_size <= 0:
+        print("Batch size must be greater than 0.", file=sys.stderr)
+        return []
+    
+    total_instances = len(device_ids) * instances_per_device
+    print(f"Configuration:")
+    print(f"  Devices: {device_ids}")
+    print(f"  Instances per device: {instances_per_device}")
+    print(f"  Total instances: {total_instances}")
+    print(f"  Batch size: {batch_size}")
+    print(f"  Total images: {len(image_paths)}")
+    
+    # Create shared queues via a Manager
+    with Manager() as manager:
+        task_queue = manager.Queue()
+        result_queue = manager.Queue()
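+        # Manager queues are process-safe proxies; a plain
+        # multiprocessing.Queue would also work, but a Manager keeps both
+        # queues under one server process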
+        
+        # Enqueue all tasks up front
+        for img_path in image_paths:
+            task_queue.put(str(img_path))
+        
+        print(f"Added {len(image_paths)} tasks to queue")
+        
+        # Create and start the worker processes
+        processes = []
+        worker_id = 0
+        
+        for device_id in device_ids:
+            for instance_idx in range(instances_per_device):
+                device = constr_device(device_type, [device_id])
+                
+                p = Process(
+                    target=worker,
+                    args=(
+                        pipeline_name,
+                        device,
+                        task_queue,
+                        result_queue,
+                        batch_size,
+                        str(output_path),
+                        worker_id,
+                    ),
+                    name=f"Worker-{worker_id}-{device}"
+                )
+                p.start()
+                processes.append(p)
+                worker_id += 1
+        
+        print(f"Started {len(processes)} worker processes")
+        
+        # Collect results
+        all_results = []
+        completed_images = 0
+        total_images = len(image_paths)
+        
+        with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
+            # Wait for all results
+            active_workers = len(processes)
+            
+            while completed_images < total_images and active_workers > 0:
+                try:
+                    # Short timeout so worker liveness is checked periodically
+                    batch_results = result_queue.get(timeout=5.0)
+                    
+                    all_results.extend(batch_results)
+                    batch_size_actual = len(batch_results)
+                    completed_images += batch_size_actual
+                    
+                    pbar.update(batch_size_actual)
+                    
+                    # Update the progress bar postfix
+                    success_count = sum(1 for r in batch_results if r.get('success', False))
+                    total_success = sum(1 for r in all_results if r.get('success', False))
+                    
+                    # Per-device statistics
+                    device_stats = {}
+                    for r in all_results:
+                        device = r.get('device', 'unknown')
+                        if device not in device_stats:
+                            device_stats[device] = {'success': 0, 'total': 0}
+                        device_stats[device]['total'] += 1
+                        if r.get('success', False):
+                            device_stats[device]['success'] += 1
+                    
+                    device_info = ', '.join([f"{k}:{v['success']}/{v['total']}" 
+                                           for k, v in device_stats.items()])
+                    
+                    pbar.set_postfix({
+                        'batch_success': f"{success_count}/{batch_size_actual}",
+                        'total_success': f"{total_success}/{completed_images}",
+                        'devices': device_info
+                    })
+                    
+                except Exception:
+                    # Check whether any workers are still alive
+                    active_workers = sum(1 for p in processes if p.is_alive())
+                    if active_workers == 0:
+                        print("All workers have finished")
+                        break
+                    
+                    # Timeout (queue.Empty) or other transient error; keep waiting
+                    continue
+        
+        # Drain any result records that arrived after the last liveness check
+        while True:
+            try:
+                all_results.extend(result_queue.get_nowait())
+            except Empty:
+                break
+        
+        # Wait for all processes to exit
+        print("Waiting for all processes to finish...")
+        for p in processes:
+            p.join(timeout=10.0)
+            if p.is_alive():
+                print(f"Force terminating process: {p.name}")
+                p.terminate()
+                p.join(timeout=5.0)
+    
+    return all_results
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
+    
+    # Input/output arguments
+    parser.add_argument("--input_dir", type=str, required=True, 
+                       help="Input directory containing images")
+    parser.add_argument("--output_dir", type=str, default="./output", 
+                       help="Output directory")
+    
+    # Pipeline configuration
+    parser.add_argument("--pipeline", type=str, default="PP-StructureV3",
+                       help="Pipeline name or config path")
+    parser.add_argument("--device", type=str, default="gpu:0,1",
+                       help="Devices for parallel inference (e.g., 'gpu:0,1,2,3')")
+    
+    # Parallelism configuration
+    parser.add_argument("--instances_per_device", type=int, default=1,
+                       help="Number of pipeline instances per device")
+    parser.add_argument("--batch_size", type=int, default=1,
+                       help="Inference batch size for each pipeline instance")
+    
+    # Input file selection
+    parser.add_argument("--input_glob_pattern", type=str, default="*",
+                       help="Pattern to find input files")
+    
+    # Test mode
+    parser.add_argument("--test_mode", action="store_true",
+                       help="Test mode: only process first 20 images")
+    
+    args = parser.parse_args()
+    
+    # Validate the input directory
+    input_dir = Path(args.input_dir)
+    if not input_dir.exists():
+        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
+        return 2
+    if not input_dir.is_dir():
+        print(f"{input_dir} is not a directory", file=sys.stderr)
+        return 2
+    
+    # Validate the output directory
+    output_dir = Path(args.output_dir)
+    if output_dir.exists() and not output_dir.is_dir():
+        print(f"{output_dir} is not a directory", file=sys.stderr)
+        return 2
+    
+    print("="*70)
+    print("PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
+    print("="*70)
+    print(f"Input directory: {input_dir}")
+    print(f"Output directory: {output_dir}")
+    print(f"Pipeline: {args.pipeline}")
+    print(f"Device: {args.device}")
+    print(f"Instances per device: {args.instances_per_device}")
+    print(f"Batch size: {args.batch_size}")
+    print(f"Input pattern: {args.input_glob_pattern}")
+    
+    # Find input files: with the default "*" pattern, search the common
+    # image/PDF extensions; otherwise use the user-supplied glob pattern
+    # (the original looped the custom pattern once per extension, collecting
+    # six duplicate copies of every match)
+    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.pdf']
+    image_files = []
+    
+    if args.input_glob_pattern != "*":
+        image_files = sorted(input_dir.glob(args.input_glob_pattern))
+    else:
+        for ext in image_extensions:
+            image_files.extend(input_dir.glob(ext))
+        image_files = sorted(set(image_files))
+    
+    if not image_files:
+        print(f"No image files found in {input_dir} with pattern {args.input_glob_pattern}")
+        return 1
+    
+    # Convert to string paths
+    image_paths = [str(f) for f in image_files]
+    
+    print(f"Found {len(image_paths)} image files")
+    
+    # Test mode: cap the number of images
+    if args.test_mode:
+        image_paths = image_paths[:20]
+        print(f"Test mode: processing only {len(image_paths)} images")
+    
+    # Start processing
+    start_time = time.time()
+    
+    try:
+        results = parallel_process_with_official_approach(
+            image_paths=image_paths,
+            pipeline_name=args.pipeline,
+            device_str=args.device,
+            instances_per_device=args.instances_per_device,
+            batch_size=args.batch_size,
+            output_dir=args.output_dir
+        )
+        
+        total_time = time.time() - start_time
+        
+        # Summary statistics
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
+        total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
+        avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
+        
+        # Per-device and per-worker statistics
+        device_stats = {}
+        worker_stats = {}
+        
+        for r in results:
+            device = r.get('device', 'unknown')
+            worker_id = r.get('worker_id', 'unknown')
+            
+            # Device stats
+            if device not in device_stats:
+                device_stats[device] = {'success': 0, 'total': 0, 'total_time': 0}
+            device_stats[device]['total'] += 1
+            if r.get('success', False):
+                device_stats[device]['success'] += 1
+                device_stats[device]['total_time'] += r.get('processing_time', 0)
+            
+            # Worker stats
+            if worker_id not in worker_stats:
+                worker_stats[worker_id] = {'success': 0, 'total': 0, 'device': device}
+            worker_stats[worker_id]['total'] += 1
+            if r.get('success', False):
+                worker_stats[worker_id]['success'] += 1
+        
+        # Assemble detailed results
+        detailed_results = {
+            "configuration": {
+                "pipeline": args.pipeline,
+                "device": args.device,
+                "instances_per_device": args.instances_per_device,
+                "batch_size": args.batch_size,
+                "input_glob_pattern": args.input_glob_pattern,
+                "test_mode": args.test_mode
+            },
+            "statistics": {
+                "total_files": len(image_paths),
+                "success_count": success_count,
+                "error_count": error_count,
+                "success_rate": success_count / len(image_paths) if image_paths else 0,
+                "total_time": total_time,
+                "avg_processing_time": avg_processing_time,
+                "throughput": len(image_paths) / total_time if total_time > 0 else 0,
+                "device_stats": device_stats,
+                "worker_stats": worker_stats
+            },
+            "results": results
+        }
+        
+        # Write the results file
+        result_file = output_dir / "processing_results.json"
+        with open(result_file, 'w', encoding='utf-8') as f:
+            json.dump(detailed_results, f, ensure_ascii=False, indent=2)
+        
+        # Print the summary
+        print("\n" + "="*70)
+        print("Processing completed!")
+        print("="*70)
+        print(f"Total files: {len(image_paths)}")
+        print(f"Successfully processed: {success_count}")
+        print(f"Failed: {error_count}")
+        print(f"Success rate: {success_count / len(image_paths) * 100:.2f}%")
+        print(f"Total time: {total_time:.2f} seconds")
+        print(f"Average processing time: {avg_processing_time:.2f} seconds/image")
+        print(f"Throughput: {len(image_paths) / total_time:.2f} images/second")
+        
+        # Device statistics
+        print("\nDevice Statistics:")
+        for device, stats in device_stats.items():
+            if stats['total'] > 0:
+                success_rate = stats['success'] / stats['total'] * 100
+                avg_time = stats['total_time'] / stats['success'] if stats['success'] > 0 else 0
+                print(f"  {device}: {stats['success']}/{stats['total']} "
+                      f"({success_rate:.1f}%), avg {avg_time:.2f}s/image")
+        
+        # Worker statistics
+        print("\nWorker Statistics:")
+        for worker_id, stats in worker_stats.items():
+            if stats['total'] > 0:
+                success_rate = stats['success'] / stats['total'] * 100
+                print(f"  Worker {worker_id} ({stats['device']}): {stats['success']}/{stats['total']} "
+                      f"({success_rate:.1f}%)")
+        
+        print(f"\nDetailed results saved to: {result_file}")
+        print("All done!")
+        
+        return 0
+        
+    except Exception as e:
+        print(f"Processing failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return 1
+
+if __name__ == "__main__":
+    if len(sys.argv) == 1:
+        # No command line arguments: run with the default configuration
+        print("No command line arguments provided. Running with default configuration...")
+        
+        # Default configuration
+        default_config = {
+            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            "output_dir": "./OmniDocBench_Results_Official",
+            "pipeline": "PP-StructureV3",
+            "device": "gpu:0,1",
+            "instances_per_device": 1,
+            "batch_size": 1,
+            "test_mode": False
+        }
+        
+        # Build argv from the defaults; store_true flags take no value, so
+        # boolean options are appended only when enabled (passing
+        # "--test_mode False" would be rejected by argparse)
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            if isinstance(value, bool):
+                if value:
+                    sys.argv.append(f"--{key}")
+            else:
+                sys.argv.extend([f"--{key}", str(value)])
+    
+    sys.exit(main())
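+# Example invocation (hypothetical paths; two GPUs, 4-image batches):
+#   python zhch/ppstructurev3_multi_gpu_multiprocess_official.py \
+#       --input_dir ./images --output_dir ./output \
+#       --device "gpu:0,1" --instances_per_device 1 --batch_size 4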