@@ -0,0 +1,533 @@
+# zhch/ppstructurev3_multi_gpu_multiprocess_official.py
+import json
+import time
+import os
+import glob
+import traceback
+import argparse
+import sys
+from pathlib import Path
+from typing import List, Dict, Any, Tuple
+from multiprocessing import Manager, Process, Queue
+from queue import Empty
+import cv2
+import numpy as np
+from paddlex import create_pipeline
+from paddlex.utils.device import constr_device, parse_device
+from tqdm import tqdm
+import paddle
+
+from dotenv import load_dotenv
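+# Load settings from a local .env file; override=True lets .env values replace existing environment variables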
+load_dotenv(override=True)
+
+def worker(pipeline_name_or_config_path: str,
+           device: str,
+           task_queue: Queue,
+           result_queue: Queue,
+           batch_size: int,
+           output_dir: str,
+           worker_id: int):
+    """
+    Worker process function, based on the official parallel_inference.md example.
+
+    Args:
+        pipeline_name_or_config_path: Pipeline name or config file path
+        device: Device string (e.g. "gpu:0")
+        task_queue: Queue of input file paths to process
+        result_queue: Queue for per-file result records
+        batch_size: Inference batch size
+        output_dir: Output directory
+        worker_id: Worker process ID
+    """
+    try:
+        # Create the pipeline instance on the assigned device
+        pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
+        print(f"Worker {worker_id} initialized with device {device}")
+
+        should_end = False
+        batch = []
+        processed_count = 0
+
+        while not should_end:
+            try:
+                input_path = task_queue.get_nowait()
+            except Empty:
+                should_end = True
+            else:
+                batch.append(input_path)
+
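+            # All tasks are enqueued before the workers start, so an empty
+            # queue means no work remains: flush any partial batch and exit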
+            if batch and (len(batch) == batch_size or should_end):
+                try:
+                    start_time = time.time()
+
+                    # Run the pipeline on the whole batch
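+                    # For multi-page inputs (e.g. PDFs), predict() yields one
+                    # result per page; page_index below keeps filenames unique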
+                    results = list(pipeline.predict(
+                        batch,
+                        use_doc_orientation_classify=True,
+                        use_doc_unwarping=False,
+                        use_seal_recognition=True,
+                        use_chart_recognition=True,
+                        use_table_recognition=True,
+                        use_formula_recognition=True,
+                    ))
+
+                    batch_processing_time = time.time() - start_time
+                    batch_results = []
+
+                    for result in results:
+                        try:
+                            input_path = Path(result.input_path)
+
+                            # Build the output stem; append the page index for multi-page inputs
+                            if result.get("page_index") is not None:
+                                output_filename = f"{input_path.stem}_{result['page_index']}"
+                            else:
+                                output_filename = f"{input_path.stem}"
+
+                            # Save JSON and Markdown outputs
+                            json_output_path = str(Path(output_dir, f"{output_filename}.json"))
+                            md_output_path = str(Path(output_dir, f"{output_filename}.md"))
+
+                            result.save_to_json(json_output_path)
+                            result.save_to_markdown(md_output_path)
+
+                            # Record the outcome for this file
+                            batch_results.append({
+                                "image_path": input_path.name,
+                                "processing_time": batch_processing_time / len(batch),  # average per input
+                                "success": True,
+                                "device": device,
+                                "worker_id": worker_id,
+                                "output_json": json_output_path,
+                                "output_md": md_output_path
+                            })
+
+                            processed_count += 1
+
+                        except Exception as e:
+                            batch_results.append({
+                                "image_path": Path(result.input_path).name if hasattr(result, 'input_path') else "unknown",
+                                "processing_time": 0,
+                                "success": False,
+                                "device": device,
+                                "worker_id": worker_id,
+                                "error": str(e)
+                            })
+
+                    # Push this batch's records to the result queue
+                    result_queue.put(batch_results)
+
+                    print(f"Worker {worker_id} ({device}) processed batch of {len(batch)} files. Total: {processed_count}")
+
+                except Exception as e:
+                    # The whole batch failed; report an error record per file
+                    error_results = []
+                    for img_path in batch:
+                        error_results.append({
+                            "image_path": Path(img_path).name,
+                            "processing_time": 0,
+                            "success": False,
+                            "device": device,
+                            "worker_id": worker_id,
+                            "error": str(e)
+                        })
+                    result_queue.put(error_results)
+
+                    print(f"Error processing batch {batch} on {device}: {e}", file=sys.stderr)
+
+                batch.clear()
+
+    except Exception as e:
+        print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+    finally:
+        print(f"Worker {worker_id} ({device}) finished")
+
+def parallel_process_with_official_approach(image_paths: List[str],
+                                            pipeline_name: str = "PP-StructureV3",
+                                            device_str: str = "gpu:0,1",
+                                            instances_per_device: int = 1,
+                                            batch_size: int = 1,
+                                            output_dir: str = "./output") -> List[Dict[str, Any]]:
+    """
+    Multi-GPU, multi-process parallel processing using the officially recommended approach.
+
+    Args:
+        image_paths: List of image file paths
+        pipeline_name: Pipeline name
+        device_str: Device string, e.g. "gpu:0,1,2,3"
+        instances_per_device: Number of pipeline instances per device
+        batch_size: Inference batch size
+        output_dir: Output directory
+
+    Returns:
+        List of per-file result records
+    """
+    # Create the output directory
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    # Parse the device string
+    try:
+        device_type, device_ids = parse_device(device_str)
+        if device_ids is None or len(device_ids) < 1:
+            print("No valid devices specified.", file=sys.stderr)
+            return []
+
+        print(f"Parsed devices: {device_type}:{device_ids}")
+
+    except Exception as e:
+        print(f"Failed to parse device string '{device_str}': {e}", file=sys.stderr)
+        return []
+
+    # Validate the batch size
+    if batch_size <= 0:
+        print("Batch size must be greater than 0.", file=sys.stderr)
+        return []
+
+    total_instances = len(device_ids) * instances_per_device
+    print("Configuration:")
+    print(f"  Devices: {device_ids}")
+    print(f"  Instances per device: {instances_per_device}")
+    print(f"  Total instances: {total_instances}")
+    print(f"  Batch size: {batch_size}")
+    print(f"  Total images: {len(image_paths)}")
+
+    # Create shared queues via a Manager
+    with Manager() as manager:
+        task_queue = manager.Queue()
+        result_queue = manager.Queue()
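+        # Manager queue proxies are picklable, so they can be handed to
+        # spawned worker processes regardless of the start method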
+
+        # Enqueue all tasks up front
+        for img_path in image_paths:
+            task_queue.put(str(img_path))
+
+        print(f"Added {len(image_paths)} tasks to queue")
+
+        # Create and start the worker processes
+        processes = []
+        worker_id = 0
+
+        for device_id in device_ids:
+            for instance_idx in range(instances_per_device):
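+                # Pin this pipeline instance to a single GPU by building a
+                # one-device string such as "gpu:0"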
+                device = constr_device(device_type, [device_id])
+
+                p = Process(
+                    target=worker,
+                    args=(
+                        pipeline_name,
+                        device,
+                        task_queue,
+                        result_queue,
+                        batch_size,
+                        str(output_path),
+                        worker_id,
+                    ),
+                    name=f"Worker-{worker_id}-{device}"
+                )
+                p.start()
+                processes.append(p)
+                worker_id += 1
+
+        print(f"Started {len(processes)} worker processes")
+
+        # Collect results
+        all_results = []
+        completed_images = 0
+        total_images = len(image_paths)
+
+        with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
+            # Wait for all results
+            active_workers = len(processes)
+
+            while completed_images < total_images and active_workers > 0:
+                try:
+                    # Use a short timeout so worker liveness is checked periodically
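+                    # (a Manager queue raises queue.Empty on timeout, which the
+                    # broad except below treats like any other transient error)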
+                    batch_results = result_queue.get(timeout=5.0)
+
+                    all_results.extend(batch_results)
+                    batch_size_actual = len(batch_results)
+                    completed_images += batch_size_actual
+
+                    pbar.update(batch_size_actual)
+
+                    # Refresh the progress-bar summary
+                    success_count = sum(1 for r in batch_results if r.get('success', False))
+                    total_success = sum(1 for r in all_results if r.get('success', False))
+
+                    # Per-device tallies
+                    device_stats = {}
+                    for r in all_results:
+                        device = r.get('device', 'unknown')
+                        if device not in device_stats:
+                            device_stats[device] = {'success': 0, 'total': 0}
+                        device_stats[device]['total'] += 1
+                        if r.get('success', False):
+                            device_stats[device]['success'] += 1
+
+                    device_info = ', '.join([f"{k}:{v['success']}/{v['total']}"
+                                             for k, v in device_stats.items()])
+
+                    pbar.set_postfix({
+                        'batch_success': f"{success_count}/{batch_size_actual}",
+                        'total_success': f"{total_success}/{completed_images}",
+                        'devices': device_info
+                    })
+
+                except Exception:
+                    # Check whether any workers are still alive
+                    active_workers = sum(1 for p in processes if p.is_alive())
+                    if active_workers == 0:
+                        print("All workers have finished")
+                        break
+
+                    # Timeout or transient error: keep waiting
+                    continue
+
+        # Wait for every process to exit
+        print("Waiting for all processes to finish...")
+        for p in processes:
+            p.join(timeout=10.0)
+            if p.is_alive():
+                print(f"Force terminating process: {p.name}")
+                p.terminate()
+                p.join(timeout=5.0)
+
+    return all_results
+
+def main():
+    """Entry point."""
+    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
+
+    # Required arguments
+    parser.add_argument("--input_dir", type=str, required=True,
+                        help="Input directory containing images")
+    parser.add_argument("--output_dir", type=str, default="./output",
+                        help="Output directory")
+
+    # Pipeline configuration
+    parser.add_argument("--pipeline", type=str, default="PP-StructureV3",
+                        help="Pipeline name or config path")
+    parser.add_argument("--device", type=str, default="gpu:0,1",
+                        help="Devices for parallel inference (e.g., 'gpu:0,1,2,3')")
+
+    # Parallelism configuration
+    parser.add_argument("--instances_per_device", type=int, default=1,
+                        help="Number of pipeline instances per device")
+    parser.add_argument("--batch_size", type=int, default=1,
+                        help="Inference batch size for each pipeline instance")
+
+    # Input file configuration
+    parser.add_argument("--input_glob_pattern", type=str, default="*",
+                        help="Pattern to find input files")
+
+    # Test mode
+    parser.add_argument("--test_mode", action="store_true",
+                        help="Test mode: only process first 20 images")
+
+    args = parser.parse_args()
+
+    # Validate the input directory
+    input_dir = Path(args.input_dir)
+    if not input_dir.exists():
+        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
+        return 2
+    if not input_dir.is_dir():
+        print(f"{input_dir} is not a directory", file=sys.stderr)
+        return 2
+
+    # Validate the output directory
+    output_dir = Path(args.output_dir)
+    if output_dir.exists() and not output_dir.is_dir():
+        print(f"{output_dir} is not a directory", file=sys.stderr)
+        return 2
+
+    print("="*70)
+    print("PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
+    print("="*70)
+    print(f"Input directory: {input_dir}")
+    print(f"Output directory: {output_dir}")
+    print(f"Pipeline: {args.pipeline}")
+    print(f"Device: {args.device}")
+    print(f"Instances per device: {args.instances_per_device}")
+    print(f"Batch size: {args.batch_size}")
+    print(f"Input pattern: {args.input_glob_pattern}")
+
+    # Find input files
+    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.pdf']
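+    # PDFs are accepted alongside images: PP-StructureV3 processes them
+    # page by page (see the page_index handling in worker())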
+    if args.input_glob_pattern != "*":
+        # Glob the user-supplied pattern exactly once; globbing it per
+        # extension would duplicate every match
+        image_files = sorted(input_dir.glob(args.input_glob_pattern))
+    else:
+        image_files = []
+        for ext in image_extensions:
+            image_files.extend(input_dir.glob(ext))
+        image_files.sort()
+
+    if not image_files:
+        print(f"No image files found in {input_dir} with pattern {args.input_glob_pattern}")
+        return 1
+
+    # Convert to string paths
+    image_paths = [str(f) for f in image_files]
+
+    print(f"Found {len(image_paths)} image files")
+
+    # Test mode
+    if args.test_mode:
+        image_paths = image_paths[:20]
+        print(f"Test mode: processing only {len(image_paths)} images")
+
+    # Start processing
+    start_time = time.time()
+
+    try:
+        results = parallel_process_with_official_approach(
+            image_paths=image_paths,
+            pipeline_name=args.pipeline,
+            device_str=args.device,
+            instances_per_device=args.instances_per_device,
+            batch_size=args.batch_size,
+            output_dir=args.output_dir
+        )
+
+        total_time = time.time() - start_time
+
+        # Summary statistics
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
+        total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
+        avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
+
+        # Per-device statistics
+        device_stats = {}
+        worker_stats = {}
+
+        for r in results:
+            device = r.get('device', 'unknown')
+            worker_id = r.get('worker_id', 'unknown')
+
+            # Device tallies
+            if device not in device_stats:
+                device_stats[device] = {'success': 0, 'total': 0, 'total_time': 0}
+            device_stats[device]['total'] += 1
+            if r.get('success', False):
+                device_stats[device]['success'] += 1
+                device_stats[device]['total_time'] += r.get('processing_time', 0)
+
+            # Worker tallies
+            if worker_id not in worker_stats:
+                worker_stats[worker_id] = {'success': 0, 'total': 0, 'device': device}
+            worker_stats[worker_id]['total'] += 1
+            if r.get('success', False):
+                worker_stats[worker_id]['success'] += 1
+
+        # Assemble the detailed results
+        detailed_results = {
+            "configuration": {
+                "pipeline": args.pipeline,
+                "device": args.device,
+                "instances_per_device": args.instances_per_device,
+                "batch_size": args.batch_size,
+                "input_glob_pattern": args.input_glob_pattern,
+                "test_mode": args.test_mode
+            },
+            "statistics": {
+                "total_files": len(image_paths),
+                "success_count": success_count,
+                "error_count": error_count,
+                "success_rate": success_count / len(image_paths) if image_paths else 0,
+                "total_time": total_time,
+                "avg_processing_time": avg_processing_time,
+                "throughput": len(image_paths) / total_time if total_time > 0 else 0,
+                "device_stats": device_stats,
+                "worker_stats": worker_stats
+            },
+            "results": results
+        }
+
+        # Save the results file
+        result_file = output_dir / "processing_results.json"
+        with open(result_file, 'w', encoding='utf-8') as f:
+            json.dump(detailed_results, f, ensure_ascii=False, indent=2)
+
+        # Print the summary
+        print("\n" + "="*70)
+        print("Processing completed!")
+        print("="*70)
+        print(f"Total files: {len(image_paths)}")
+        print(f"Successfully processed: {success_count}")
+        print(f"Failed: {error_count}")
+        print(f"Success rate: {success_count / len(image_paths) * 100:.2f}%")
+        print(f"Total time: {total_time:.2f} seconds")
+        print(f"Average processing time: {avg_processing_time:.2f} seconds/image")
+        print(f"Throughput: {len(image_paths) / total_time:.2f} images/second")
+
+        # Device statistics
+        print("\nDevice Statistics:")
+        for device, stats in device_stats.items():
+            if stats['total'] > 0:
+                success_rate = stats['success'] / stats['total'] * 100
+                avg_time = stats['total_time'] / stats['success'] if stats['success'] > 0 else 0
+                print(f"  {device}: {stats['success']}/{stats['total']} "
+                      f"({success_rate:.1f}%), avg {avg_time:.2f}s/image")
+
+        # Worker statistics
+        print("\nWorker Statistics:")
+        for worker_id, stats in worker_stats.items():
+            if stats['total'] > 0:
+                success_rate = stats['success'] / stats['total'] * 100
+                print(f"  Worker {worker_id} ({stats['device']}): {stats['success']}/{stats['total']} "
+                      f"({success_rate:.1f}%)")
+
+        print(f"\nDetailed results saved to: {result_file}")
+        print("All done!")
+
+        return 0
+
+    except Exception as e:
+        print(f"Processing failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return 1
+
+if __name__ == "__main__":
+    if len(sys.argv) == 1:
+        # No command-line arguments: run with the default configuration
+        print("No command line arguments provided. Running with default configuration...")
+
+        # Default configuration (set "test_mode": True to process only 20 images)
+        default_config = {
+            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            "output_dir": "./OmniDocBench_Results_Official",
+            "pipeline": "PP-StructureV3",
+            "device": "gpu:0,1",
+            "instances_per_device": 1,
+            "batch_size": 1,
+            "test_mode": False
+        }
+
+        # Build argv; store_true flags such as test_mode take no value, so
+        # emit them only when enabled instead of passing "--test_mode False"
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            if isinstance(value, bool):
+                if value:
+                    sys.argv.append(f"--{key}")
+            else:
+                sys.argv.extend([f"--{key}", str(value)])
+
+    sys.exit(main())