瀏覽代碼

refactor(zhch): 优化多进程中paddle的导入和设备设置,避免主进程CUDA冲突

zhch158_admin 3 月之前
父節點
當前提交
4ae079b3da
共有 1 個文件被更改,包括 87 次插入和 223 次删除
  1. 87 223
      zhch/ppstructurev3_multi_gpu_multiprocess_official.py

+ 87 - 223
zhch/ppstructurev3_multi_gpu_multiprocess_official.py

@@ -15,8 +15,8 @@ import numpy as np
 from paddlex import create_pipeline
 from paddlex.utils.device import constr_device, parse_device
 from tqdm import tqdm
-import paddle
-from cuda_utils import detect_available_gpus, monitor_gpu_memory
+# import paddle  # ❌ 不要在主模块导入paddle
+# from cuda_utils import detect_available_gpus, monitor_gpu_memory  # ❌ 不要在主进程使用
 
 from dotenv import load_dotenv
 load_dotenv(override=True)
@@ -41,9 +41,18 @@ def worker(pipeline_name_or_config_path: str,
         worker_id: 工作进程ID
     """
     try:
-        # 创建pipeline实例
+        # 在子进程中导入paddle,避免主进程CUDA冲突
+        import paddle
+        import os
+        
+        # 设置子进程的CUDA设备
+        device_id = device.split(':')[1] if ':' in device else '0'
+        os.environ['CUDA_VISIBLE_DEVICES'] = device_id
+        
+        # 直接创建pipeline,让PaddleX自动处理设备初始化
         pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
         print(f"Worker {worker_id} initialized with device {device}")
+        
     except Exception as e:
         print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
         traceback.print_exc()
@@ -180,7 +189,7 @@ def parallel_process_with_official_approach(image_paths: List[str],
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
     
-    # 解析设备
+    # 解析设备 - 不要在主进程中初始化paddle
     try:
         device_type, device_ids = parse_device(device_str)
         if device_ids is None or len(device_ids) < 1:
@@ -222,9 +231,8 @@ def parallel_process_with_official_approach(image_paths: List[str],
         worker_id = 0
         
         for device_id in device_ids:
-            for instance_idx in range(instances_per_device):
+            for _ in range(instances_per_device):
                 device = constr_device(device_type, [device_id])
-                
                 p = Process(
                     target=worker,
                     args=(
@@ -236,7 +244,6 @@ def parallel_process_with_official_approach(image_paths: List[str],
                         str(output_path),
                         worker_id,
                     ),
-                    name=f"Worker-{worker_id}-{device}"
                 )
                 p.start()
                 processes.append(p)
@@ -244,67 +251,42 @@ def parallel_process_with_official_approach(image_paths: List[str],
         
         print(f"Started {len(processes)} worker processes")
         
+        # 发送结束信号
+        for _ in range(total_instances):
+            task_queue.put(None)
+        
         # 收集结果
         all_results = []
-        completed_images = 0
         total_images = len(image_paths)
         
         with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
-            # 等待所有结果
-            active_workers = len(processes)
+            completed_count = 0
             
-            while completed_images < total_images and active_workers > 0:
+            while completed_count < total_images:
                 try:
-                    # 设置较短的超时时间,定期检查进程状态
-                    batch_results = result_queue.get(timeout=5.0)
-                    
+                    batch_results = result_queue.get(timeout=300)  # 5分钟超时
                     all_results.extend(batch_results)
-                    batch_size_actual = len(batch_results)
-                    completed_images += batch_size_actual
                     
-                    pbar.update(batch_size_actual)
-                    
-                    # 更新进度条信息
-                    success_count = sum(1 for r in batch_results if r.get('success', False))
-                    total_success = sum(1 for r in all_results if r.get('success', False))
-                    
-                    # 按设备统计
-                    device_stats = {}
-                    for r in all_results:
-                        device = r.get('device', 'unknown')
-                        if device not in device_stats:
-                            device_stats[device] = {'success': 0, 'total': 0}
-                        device_stats[device]['total'] += 1
-                        if r.get('success', False):
-                            device_stats[device]['success'] += 1
-                    
-                    device_info = ', '.join([f"{k}:{v['success']}/{v['total']}" 
-                                           for k, v in device_stats.items()])
+                    # 更新进度条
+                    batch_success_count = sum(1 for r in batch_results if r.get('success', False))
+                    completed_count += len(batch_results)
+                    pbar.update(len(batch_results))
                     
+                    # 显示当前批次状态
                     pbar.set_postfix({
-                        'batch_success': f"{success_count}/{batch_size_actual}",
-                        'total_success': f"{total_success}/{completed_images}",
-                        'devices': device_info
+                        'batch_success': f"{batch_success_count}/{len(batch_results)}",
+                        'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
                     })
                     
                 except Exception as e:
-                    # 检查是否还有活跃的进程
-                    active_workers = sum(1 for p in processes if p.is_alive())
-                    if active_workers == 0:
-                        print("All workers have finished")
-                        break
-                    
-                    # 超时或其他错误,继续等待
-                    continue
+                    print(f"Error collecting results: {e}")
+                    break
         
         # 等待所有进程结束
-        print("Waiting for all processes to finish...")
         for p in processes:
-            p.join(timeout=10.0)
+            p.join(timeout=10)
             if p.is_alive():
-                print(f"Force terminating process: {p.name}")
                 p.terminate()
-                p.join(timeout=5.0)
     
     return all_results
 
@@ -313,189 +295,66 @@ def main():
     parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
     
     # 必需参数
-    parser.add_argument("--input_dir", type=str, required=True, 
-                       help="Input directory containing images")
-    parser.add_argument("--output_dir", type=str, default="./output", 
-                       help="Output directory")
-    
-    # Pipeline配置
-    parser.add_argument("--pipeline", type=str, default="PP-StructureV3",
-                       help="Pipeline name or config path")
-    parser.add_argument("--device", type=str, default="gpu:0,1",
-                       help="Devices for parallel inference (e.g., 'gpu:0,1,2,3')")
-    
-    # 并行配置
-    parser.add_argument("--instances_per_device", type=int, default=1,
-                       help="Number of pipeline instances per device")
-    parser.add_argument("--batch_size", type=int, default=1,
-                       help="Inference batch size for each pipeline instance")
-    
-    # 输入文件配置
-    parser.add_argument("--input_glob_pattern", type=str, default="*",
-                       help="Pattern to find input files")
-    
-    # 测试模式
-    parser.add_argument("--test_mode", action="store_true",
-                       help="Test mode: only process first 20 images")
+    parser.add_argument("--input_dir", type=str, default="../../OmniDocBench/OpenDataLab___OmniDocBench/images", help="Input directory")
+    parser.add_argument("--output_dir", type=str, default="./OmniDocBench_Results_Official", help="Output directory")
+    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
+    parser.add_argument("--device", type=str, default="gpu:0", help="Device string")
+    parser.add_argument("--instances_per_device", type=int, default=1, help="Instances per device")
+    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
+    parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
+    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
     
     args = parser.parse_args()
     
-    # 验证输入目录
-    input_dir = Path(args.input_dir)
-    if not input_dir.exists():
-        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
-        return 2
-    if not input_dir.is_dir():
-        print(f"{input_dir} is not a directory", file=sys.stderr)
-        return 2
-    
-    # 验证输出目录
-    output_dir = Path(args.output_dir)
-    if output_dir.exists() and not output_dir.is_dir():
-        print(f"{output_dir} is not a directory", file=sys.stderr)
-        return 2
-    
-    print("="*70)
-    print("PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
-    print("="*70)
-    print(f"Input directory: {input_dir}")
-    print(f"Output directory: {output_dir}")
-    print(f"Pipeline: {args.pipeline}")
-    print(f"Device: {args.device}")
-    print(f"Instances per device: {args.instances_per_device}")
-    print(f"Batch size: {args.batch_size}")
-    print(f"Input pattern: {args.input_glob_pattern}")
-    
-    # 查找图像文件
-    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.pdf']
-    image_files = []
-    
-    for ext in image_extensions:
-        pattern = args.input_glob_pattern if args.input_glob_pattern != "*" else ext
-        image_files.extend(input_dir.glob(pattern))
-    
-    # 如果没有找到文件,尝试使用用户指定的模式
-    if not image_files and args.input_glob_pattern != "*":
-        image_files = list(input_dir.glob(args.input_glob_pattern))
-    
-    if not image_files:
-        print(f"No image files found in {input_dir} with pattern {args.input_glob_pattern}")
-        return 1
-    
-    # 转换为字符串路径
-    image_paths = [str(f) for f in image_files]
-    
-    print(f"Found {len(image_paths)} image files")
-    
-    # 测试模式
-    if args.test_mode:
-        image_paths = image_paths[:20]
-        print(f"Test mode: processing only {len(image_paths)} images")
-    
-    # 开始处理
-    start_time = time.time()
-    
     try:
-        results = parallel_process_with_official_approach(
-            image_paths=image_paths,
-            pipeline_name=args.pipeline,
-            device_str=args.device,
-            instances_per_device=args.instances_per_device,
-            batch_size=args.batch_size,
-            output_dir=args.output_dir
-        )
+        # 获取图像文件列表
+        input_dir = Path(args.input_dir)
+        if not input_dir.exists():
+            print(f"Input directory does not exist: {input_dir}")
+            return 1
         
-        total_time = time.time() - start_time
+        # 查找图像文件
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+        image_files = []
+        for ext in image_extensions:
+            image_files.extend(list(input_dir.glob(f"*{ext}")))
+            image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
         
-        # 统计信息
-        success_count = sum(1 for r in results if r.get('success', False))
-        error_count = len(results) - success_count
-        total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
-        avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
+        if not image_files:
+            print(f"No image files found in {input_dir}")
+            return 1
         
-        # 按设备统计
-        device_stats = {}
-        worker_stats = {}
+        image_files = [str(f) for f in image_files]
+        print(f"Found {len(image_files)} image files")
         
-        for r in results:
-            device = r.get('device', 'unknown')
-            worker_id = r.get('worker_id', 'unknown')
-            
-            # 设备统计
-            if device not in device_stats:
-                device_stats[device] = {'success': 0, 'total': 0, 'total_time': 0}
-            device_stats[device]['total'] += 1
-            if r.get('success', False):
-                device_stats[device]['success'] += 1
-                device_stats[device]['total_time'] += r.get('processing_time', 0)
-            
-            # Worker统计
-            if worker_id not in worker_stats:
-                worker_stats[worker_id] = {'success': 0, 'total': 0, 'device': device}
-            worker_stats[worker_id]['total'] += 1
-            if r.get('success', False):
-                worker_stats[worker_id]['success'] += 1
+        if args.test_mode:
+            image_files = image_files[:20]
+            print(f"Test mode: processing only {len(image_files)} images")
         
-        # 保存详细结果
-        detailed_results = {
-            "configuration": {
-                "pipeline": args.pipeline,
-                "device": args.device,
-                "instances_per_device": args.instances_per_device,
-                "batch_size": args.batch_size,
-                "input_glob_pattern": args.input_glob_pattern,
-                "test_mode": args.test_mode
-            },
-            "statistics": {
-                "total_files": len(image_paths),
-                "success_count": success_count,
-                "error_count": error_count,
-                "success_rate": success_count / len(image_paths) if image_paths else 0,
-                "total_time": total_time,
-                "avg_processing_time": avg_processing_time,
-                "throughput": len(image_paths) / total_time if total_time > 0 else 0,
-                "device_stats": device_stats,
-                "worker_stats": worker_stats
-            },
-            "results": results
-        }
+        # 开始处理
+        start_time = time.time()
+        results = parallel_process_with_official_approach(
+            image_files,
+            args.pipeline,
+            args.device,
+            args.instances_per_device,
+            args.batch_size,
+            args.output_dir
+        )
+        total_time = time.time() - start_time
         
-        # 保存结果文件
-        result_file = output_dir / "processing_results.json"
-        with open(result_file, 'w', encoding='utf-8') as f:
-            json.dump(detailed_results, f, ensure_ascii=False, indent=2)
+        # 统计结果
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
         
-        # 打印统计信息
-        print("\n" + "="*70)
-        print("Processing completed!")
-        print("="*70)
-        print(f"Total files: {len(image_paths)}")
-        print(f"Successfully processed: {success_count}")
+        print(f"\n" + "="*50)
+        print(f"Processing completed!")
+        print(f"Total files: {len(image_files)}")
+        print(f"Successful: {success_count}")
         print(f"Failed: {error_count}")
-        print(f"Success rate: {success_count / len(image_paths) * 100:.2f}%")
+        print(f"Success rate: {success_count / len(image_files) * 100:.2f}%")
         print(f"Total time: {total_time:.2f} seconds")
-        print(f"Average processing time: {avg_processing_time:.2f} seconds/image")
-        print(f"Throughput: {len(image_paths) / total_time:.2f} images/second")
-        
-        # 设备统计
-        print(f"\nDevice Statistics:")
-        for device, stats in device_stats.items():
-            if stats['total'] > 0:
-                success_rate = stats['success'] / stats['total'] * 100
-                avg_time = stats['total_time'] / stats['success'] if stats['success'] > 0 else 0
-                print(f"  {device}: {stats['success']}/{stats['total']} "
-                      f"({success_rate:.1f}%), avg {avg_time:.2f}s/image")
-        
-        # Worker统计
-        print(f"\nWorker Statistics:")
-        for worker_id, stats in worker_stats.items():
-            if stats['total'] > 0:
-                success_rate = stats['success'] / stats['total'] * 100
-                print(f"  Worker {worker_id} ({stats['device']}): {stats['success']}/{stats['total']} "
-                      f"({success_rate:.1f}%)")
-        
-        print(f"\nDetailed results saved to: {result_file}")
-        print("All done!")
+        print(f"Throughput: {len(image_files) / total_time:.2f} images/second")
         
         return 0
         
@@ -505,12 +364,17 @@ def main():
         return 1
 
 if __name__ == "__main__":
+    # ❌ 移除所有主进程CUDA操作
+    # print(f"🚀 启动OCR程序...")
+    # print(f"CUDA 版本: {paddle.device.cuda.get_device_name()}")
+    # print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
+    # available_gpus = detect_available_gpus()
+    # monitor_gpu_memory(available_gpus)
+    
+    # ✅ 只进行简单的环境检查
     print(f"🚀 启动OCR程序...")
-    print(f"CUDA 版本: {paddle.device.cuda.get_device_name()}")
     print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
-    available_gpus = detect_available_gpus()
-    monitor_gpu_memory(available_gpus)
-
+    
     if len(sys.argv) == 1:
         # 如果没有命令行参数,使用默认配置运行
         print("No command line arguments provided. Running with default configuration...")