
feat(zhch): add multi-process scheduler with parallel image-file processing and result saving

zhch158_admin 3 months ago
parent commit 7996999a26
1 changed file with 380 additions and 0 deletions

+ 380 - 0
zhch/ppstructurev3_scheduler.py

@@ -0,0 +1,380 @@
+import json
+import time
+import os
+import argparse
+import sys
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import List, Dict, Any, Tuple
+from concurrent.futures import ProcessPoolExecutor, as_completed
+import threading
+from queue import Queue, Empty
+from tqdm import tqdm
+
+def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
+    """
+    Split a file list into the given number of roughly equal chunks.
+    
+    Args:
+        file_list: list of file paths
+        num_splits: number of chunks to produce
+        
+    Returns:
+        List of file-path chunks
+    """
+    if num_splits <= 0:
+        return [file_list]
+    
+    chunk_size = len(file_list) // num_splits
+    remainder = len(file_list) % num_splits
+    
+    chunks = []
+    start = 0
+    
+    for i in range(num_splits):
+        # The first `remainder` chunks get one extra file each
+        current_chunk_size = chunk_size + (1 if i < remainder else 0)
+        if current_chunk_size > 0:
+            chunks.append(file_list[start:start + current_chunk_size])
+            start += current_chunk_size
+    
+    return [chunk for chunk in chunks if chunk]  # drop empty chunks
+
+def create_temp_file_list(file_chunk: List[str]) -> str:
+    """
+    Write a chunk of file paths to a temporary list file, one path per line.
+    
+    Args:
+        file_chunk: list of file paths
+        
+    Returns:
+        Path of the temporary file
+    """
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
+        for file_path in file_chunk:
+            f.write(f"{file_path}\n")
+        return f.name
+
+def run_single_process(args: Tuple[List[str], Dict[str, Any], int]) -> Dict[str, Any]:
+    """
+    Run one ppstructurev3_single_process.py worker as a subprocess.
+    
+    Args:
+        args: tuple of (file_chunk, config, process_id)
+        
+    Returns:
+        Result dictionary describing the worker run
+    """
+    file_chunk, config, process_id = args
+    
+    if not file_chunk:
+        return {"process_id": process_id, "success": False, "error": "Empty file chunk"}
+    
+    # Write the chunk to a temporary file list
+    temp_file_list = create_temp_file_list(file_chunk)
+    
+    try:
+        # Create a per-process output directory
+        process_output_dir = Path(config["output_dir"]) / f"process_{process_id}"
+        process_output_dir.mkdir(parents=True, exist_ok=True)
+        
+        # Build the command line for the worker
+        cmd = [
+            sys.executable,
+            config["single_process_script"],
+            "--input_file_list", temp_file_list,  # requires the single-process script to accept a file list (see the sketch after the diff)
+            "--output_dir", str(process_output_dir),
+            "--pipeline", config["pipeline"],
+            "--device", config["device"],
+            "--batch_size", str(config["batch_size"]),
+        ]
+        
+        # Optional flags
+        if config.get("test_mode", False):
+            cmd.append("--test_mode")
+        
+        print(f"Process {process_id} starting with {len(file_chunk)} files on device {config['device']}")
+        
+        # Run the worker subprocess
+        start_time = time.time()
+        result = subprocess.run(
+            cmd,
+            capture_output=True,
+            text=True,
+            timeout=config.get("timeout", 3600)  # default timeout: 1 hour
+        )
+        
+        processing_time = time.time() - start_time
+        
+        if result.returncode == 0:
+            print(f"Process {process_id} completed successfully in {processing_time:.2f}s")
+            
+            # Collect the result files produced by the worker
+            result_files = list(process_output_dir.glob("*.json"))
+            
+            return {
+                "process_id": process_id,
+                "success": True,
+                "processing_time": processing_time,
+                "file_count": len(file_chunk),
+                "device": config["device"],
+                "output_dir": str(process_output_dir),
+                "result_files": [str(f) for f in result_files],
+                "stdout": result.stdout,
+                "stderr": result.stderr
+            }
+        else:
+            print(f"Process {process_id} failed with return code {result.returncode}")
+            return {
+                "process_id": process_id,
+                "success": False,
+                "error": f"Process failed with return code {result.returncode}",
+                "stdout": result.stdout,
+                "stderr": result.stderr
+            }
+            
+    except subprocess.TimeoutExpired:
+        print(f"Process {process_id} timed out")
+        return {
+            "process_id": process_id,
+            "success": False,
+            "error": "Process timeout"
+        }
+    except Exception as e:
+        print(f"Process {process_id} error: {e}")
+        return {
+            "process_id": process_id,
+            "success": False,
+            "error": str(e)
+        }
+    finally:
+        # Clean up the temporary file list
+        try:
+            os.unlink(temp_file_list)
+        except OSError:
+            pass
+
+def monitor_progress(total_files: int, completed_queue: Queue):
+    """
+    Monitor overall progress: consume per-chunk completion counts from the
+    queue and update a tqdm bar (intended to run in a daemon thread).
+    """
+    with tqdm(total=total_files, desc="Total Progress", unit="files") as pbar:
+        completed_count = 0
+        while completed_count < total_files:
+            try:
+                batch_count = completed_queue.get(timeout=1)
+                completed_count += batch_count
+                pbar.update(batch_count)
+            except Empty:
+                continue
+
+def main():
+    """Entry point for the multi-process scheduler."""
+    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-Process Scheduler")
+    
+    # Input/output arguments
+    parser.add_argument("--input_dir", type=str, required=True, help="Input directory")
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    parser.add_argument("--single_process_script", type=str, 
+                       default="./ppstructurev3_single_process.py", 
+                       help="Path to single process script")
+    
+    # Parallelism arguments
+    parser.add_argument("--num_processes", type=int, default=4, help="Number of parallel processes")
+    parser.add_argument("--devices", type=str, default="gpu:0,gpu:1,gpu:2,gpu:3", 
+                       help="Device list (comma separated)")
+    
+    # Pipeline arguments
+    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
+    parser.add_argument("--batch_size", type=int, default=4, help="Batch size per process")
+    parser.add_argument("--timeout", type=int, default=3600, help="Process timeout in seconds")
+    
+    # Miscellaneous arguments
+    parser.add_argument("--test_mode", action="store_true", help="Test mode")
+    parser.add_argument("--max_files", type=int, default=None, help="Maximum files to process")
+    
+    args = parser.parse_args()
+    
+    try:
+        # Resolve input and output directories
+        input_dir = Path(args.input_dir).resolve()
+        output_dir = Path(args.output_dir).resolve()
+        
+        print(f"Input dir: {input_dir}")
+        print(f"Output dir: {output_dir}")
+        
+        if not input_dir.exists():
+            print(f"Input directory does not exist: {input_dir}")
+            return 1
+        
+        # Find image files
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+        image_files = []
+        for ext in image_extensions:
+            image_files.extend(list(input_dir.glob(f"*{ext}")))
+            image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+        
+        if not image_files:
+            print(f"No image files found in {input_dir}")
+            return 1
+        
+        # Deduplicate and sort
+        image_files = sorted(list(set(str(f) for f in image_files)))
+        
+        # Optionally limit the number of files
+        if args.max_files:
+            image_files = image_files[:args.max_files]
+            
+        if args.test_mode:
+            image_files = image_files[:20]
+            print(f"Test mode: processing only {len(image_files)} images")
+        
+        print(f"Found {len(image_files)} image files")
+        
+        # Parse the device list
+        devices = [d.strip() for d in args.devices.split(',')]
+        if len(devices) < args.num_processes:
+            # Fewer devices than processes: reuse devices round-robin
+            devices = devices * ((args.num_processes // len(devices)) + 1)
+        devices = devices[:args.num_processes]
+        
+        print(f"Using {args.num_processes} processes with devices: {devices}")
+        
+        # Split the file list into chunks
+        file_chunks = split_files(image_files, args.num_processes)
+        print(f"Split into {len(file_chunks)} chunks: {[len(chunk) for chunk in file_chunks]}")
+        
+        # Create the output directory
+        output_dir.mkdir(parents=True, exist_ok=True)
+        
+        # Prepare per-process arguments
+        process_configs = []
+        for i, (chunk, device) in enumerate(zip(file_chunks, devices)):
+            config = {
+                "single_process_script": str(Path(args.single_process_script).resolve()),
+                "output_dir": str(output_dir),
+                "pipeline": args.pipeline,
+                "device": device,
+                "batch_size": args.batch_size,
+                "timeout": args.timeout,
+                "test_mode": args.test_mode
+            }
+            process_configs.append((chunk, config, i))
+        
+        # Start the progress-monitoring thread
+        completed_queue = Queue()
+        progress_thread = threading.Thread(
+            target=monitor_progress, 
+            args=(len(image_files), completed_queue)
+        )
+        progress_thread.daemon = True
+        progress_thread.start()
+        
+        # Run the workers in parallel
+        start_time = time.time()
+        results = []
+        
+        with ProcessPoolExecutor(max_workers=args.num_processes) as executor:
+            # Submit all worker tasks
+            future_to_process = {
+                executor.submit(run_single_process, proc_args): i
+                for i, proc_args in enumerate(process_configs)
+            }
+            
+            # Collect results as they complete
+            for future in as_completed(future_to_process):
+                process_id = future_to_process[future]
+                try:
+                    result = future.result()
+                    results.append(result)
+                    
+                    # Update the progress bar
+                    if result.get("success", False):
+                        completed_queue.put(result.get("file_count", 0))
+                    
+                    print(f"Process {process_id} finished: {result.get('success', False)}")
+                    
+                except Exception as e:
+                    print(f"Process {process_id} generated an exception: {e}")
+                    results.append({
+                        "process_id": process_id,
+                        "success": False,
+                        "error": str(e)
+                    })
+        
+        total_time = time.time() - start_time
+        
+        # Aggregate statistics
+        successful_processes = sum(1 for r in results if r.get('success', False))
+        total_processed_files = sum(r.get('file_count', 0) for r in results if r.get('success', False))
+        
+        print("\n" + "="*60)
+        print(f"🎉 Parallel processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total processes: {len(results)}")
+        print(f"  Successful processes: {successful_processes}")
+        print(f"  Total files processed: {total_processed_files}/{len(image_files)}")
+        print(f"  Success rate: {total_processed_files/len(image_files)*100:.2f}%")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
+        if total_processed_files > 0 and total_time > 0:
+            print(f"  Throughput: {total_processed_files/total_time:.2f} files/second")
+            print(f"  Avg time per file: {total_time/total_processed_files:.2f} seconds")
+        
+        # Save scheduler statistics
+        scheduler_stats = {
+            "total_files": len(image_files),
+            "total_processes": len(results),
+            "successful_processes": successful_processes,
+            "total_processed_files": total_processed_files,
+            "success_rate": total_processed_files / len(image_files) if len(image_files) > 0 else 0,
+            "total_time": total_time,
+            "throughput": total_processed_files / total_time if total_time > 0 else 0,
+            "avg_time_per_file": total_time / total_processed_files if total_processed_files > 0 else 0,
+            "num_processes": args.num_processes,
+            "devices": devices,
+            "batch_size": args.batch_size,
+            "pipeline": args.pipeline,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+        }
+        
+        final_results = {
+            "scheduler_stats": scheduler_stats,
+            "process_results": results
+        }
+        
+        # Write the combined results to JSON
+        output_file = output_dir / f"scheduler_results_{args.num_processes}procs.json"
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
+        
+        print(f"💾 Scheduler results saved to: {output_file}")
+        
+        return 0 if successful_processes == len(results) else 1
+        
+    except Exception as e:
+        print(f"❌ Scheduler failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+if __name__ == "__main__":
+    print("🚀 Starting multi-process scheduler...")
+    
+    if len(sys.argv) == 1:
+        # Default configuration used when the script is run without arguments
+        default_config = {
+            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            "output_dir": "./OmniDocBench_Results_Scheduler",
+            "num_processes": 4,
+            "devices": "gpu:0,gpu:1,gpu:2,gpu:3",
+            "batch_size": 2,
+        }
+        
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            sys.argv.extend([f"--{key}", str(value)])
+        
+        sys.argv.append("--test_mode")
+    
+    sys.exit(main())
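
Note: the inline comment on the "--input_file_list" argument points out that ppstructurev3_single_process.py still has to be changed to accept a file-list input, and that worker script is not part of this commit. Below is a minimal sketch, under that assumption, of the argument handling the worker would need; the flag names mirror the command the scheduler builds, while the function names and the pipeline invocation itself are hypothetical placeholders.

import argparse
from pathlib import Path
from typing import List

def load_input_file_list(list_path: str) -> List[str]:
    """Read one image path per line, skipping blank lines (format written by create_temp_file_list)."""
    with open(list_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="PP-StructureV3 single-process worker (sketch)")
    parser.add_argument("--input_file_list", type=str, required=True,
                        help="Text file with one image path per line, written by the scheduler")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3")
    parser.add_argument("--device", type=str, default="gpu:0")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--test_mode", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    image_paths = load_input_file_list(args.input_file_list)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # ... run the pipeline over image_paths in batches of args.batch_size ...
    print(f"Worker would process {len(image_paths)} files on {args.device}")

Passing each chunk through a temporary list file keeps the worker command lines short regardless of how many images a chunk contains, at the cost of the scheduler having to clean the temporary files up afterwards (handled in the finally block of run_single_process).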