ppstructurev3_multi_gpu_multiprocess_official.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # zhch/ppstructurev3_multi_gpu_multiprocess_official.py
  2. import json
  3. import time
  4. import os
  5. import glob
  6. import traceback
  7. import argparse
  8. import sys
  9. from pathlib import Path
  10. from typing import List, Dict, Any, Tuple
  11. from multiprocessing import Manager, Process, Queue
  12. from queue import Empty
  13. import cv2
  14. import numpy as np
  15. from paddlex import create_pipeline
  16. from paddlex.utils.device import constr_device, parse_device
  17. from tqdm import tqdm
  18. import paddle
  19. from cuda_utils import detect_available_gpus, monitor_gpu_memory
  20. from dotenv import load_dotenv
# Load variables from a local .env file into os.environ when this module is imported;
# override=True lets values from .env replace variables already set in the environment.
load_dotenv(override=True)
  22. def worker(pipeline_name_or_config_path: str,
  23. device: str,
  24. task_queue: Queue,
  25. result_queue: Queue,
  26. batch_size: int,
  27. output_dir: str,
  28. worker_id: int):
  29. """
  30. 工作进程函数 - 基于官方parallel_inference.md实现
  31. Args:
  32. pipeline_name_or_config_path: Pipeline名称或配置路径
  33. device: 设备字符串
  34. task_queue: 任务队列
  35. result_queue: 结果队列
  36. batch_size: 批处理大小
  37. output_dir: 输出目录
  38. worker_id: 工作进程ID
  39. """
  40. try:
  41. # 创建pipeline实例
  42. from dotenv import load_dotenv
  43. load_dotenv(override=True)
  44. print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
  45. import paddle
  46. paddle.set_device(device)
  47. pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
  48. print(f"Worker {worker_id} initialized with device {device}")
  49. except Exception as e:
  50. print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
  51. traceback.print_exc()
  52. # 发送错误信息到结果队列
  53. result_queue.put([{
  54. "error": f"Worker initialization failed: {str(e)}",
  55. "worker_id": worker_id,
  56. "device": device,
  57. "success": False
  58. }])
  59. return
  60. try:
  61. should_end = False
  62. batch = []
  63. processed_count = 0
  64. while not should_end:
  65. try:
  66. input_path = task_queue.get_nowait()
  67. except Empty:
  68. should_end = True
  69. else:
  70. batch.append(input_path)
  71. if batch and (len(batch) == batch_size or should_end):
  72. try:
  73. start_time = time.time()
  74. # 使用pipeline预测
  75. results = pipeline.predict(
  76. batch,
  77. use_doc_orientation_classify=True,
  78. use_doc_unwarping=False,
  79. use_seal_recognition=True,
  80. use_chart_recognition=True,
  81. use_table_recognition=True,
  82. use_formula_recognition=True,
  83. )
  84. batch_processing_time = time.time() - start_time
  85. batch_results = []
  86. for result in results:
  87. try:
  88. input_path = Path(result.input_path)
  89. # 保存结果
  90. if result.get("page_index") is not None:
  91. output_filename = f"{input_path.stem}_{result['page_index']}"
  92. else:
  93. output_filename = f"{input_path.stem}"
  94. # 保存JSON和Markdown
  95. json_output_path = str(Path(output_dir, f"{output_filename}.json"))
  96. md_output_path = str(Path(output_dir, f"{output_filename}.md"))
  97. result.save_to_json(json_output_path)
  98. result.save_to_markdown(md_output_path)
  99. # 记录处理结果
  100. batch_results.append({
  101. "image_path": input_path.name,
  102. "processing_time": batch_processing_time / len(batch), # 平均时间
  103. "success": True,
  104. "device": device,
  105. "worker_id": worker_id,
  106. "output_json": json_output_path,
  107. "output_md": md_output_path
  108. })
  109. processed_count += 1
  110. except Exception as e:
  111. batch_results.append({
  112. "image_path": Path(result.input_path).name if hasattr(result, 'input_path') else "unknown",
  113. "processing_time": 0,
  114. "success": False,
  115. "device": device,
  116. "worker_id": worker_id,
  117. "error": str(e)
  118. })
  119. # 将结果放入结果队列
  120. result_queue.put(batch_results)
  121. print(f"Worker {worker_id} ({device}) processed batch of {len(batch)} files. Total: {processed_count}")
  122. except Exception as e:
  123. # 批处理失败
  124. error_results = []
  125. for img_path in batch:
  126. error_results.append({
  127. "image_path": Path(img_path).name,
  128. "processing_time": 0,
  129. "success": False,
  130. "device": device,
  131. "worker_id": worker_id,
  132. "error": str(e)
  133. })
  134. result_queue.put(error_results)
  135. print(f"Error processing batch {batch} on {device}: {e}", file=sys.stderr)
  136. batch.clear()
  137. except Exception as e:
  138. print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
  139. traceback.print_exc()
  140. finally:
  141. print(f"Worker {worker_id} ({device}) finished")
  142. def parallel_process_with_official_approach(image_paths: List[str],
  143. pipeline_name: str = "PP-StructureV3",
  144. device_str: str = "gpu:0,1",
  145. instances_per_device: int = 1,
  146. batch_size: int = 1,
  147. output_dir: str = "./output") -> List[Dict[str, Any]]:
  148. """
  149. 使用官方推荐的方法进行多GPU多进程并行处理
  150. Args:
  151. image_paths: 图像路径列表
  152. pipeline_name: Pipeline名称
  153. device_str: 设备字符串,如"gpu:0,1,2,3"
  154. instances_per_device: 每个设备的实例数
  155. batch_size: 批处理大小
  156. output_dir: 输出目录
  157. Returns:
  158. 处理结果列表
  159. """
  160. # 创建输出目录
  161. output_path = Path(output_dir)
  162. output_path.mkdir(parents=True, exist_ok=True)
  163. # 解析设备
  164. try:
  165. device_type, device_ids = parse_device(device_str)
  166. if device_ids is None or len(device_ids) < 1:
  167. print("No valid devices specified.", file=sys.stderr)
  168. return []
  169. print(f"Parsed devices: {device_type}:{device_ids}")
  170. except Exception as e:
  171. print(f"Failed to parse device string '{device_str}': {e}", file=sys.stderr)
  172. return []
  173. # 验证批处理大小
  174. if batch_size <= 0:
  175. print("Batch size must be greater than 0.", file=sys.stderr)
  176. return []
  177. total_instances = len(device_ids) * instances_per_device
  178. print(f"Configuration:")
  179. print(f" Devices: {device_ids}")
  180. print(f" Instances per device: {instances_per_device}")
  181. print(f" Total instances: {total_instances}")
  182. print(f" Batch size: {batch_size}")
  183. print(f" Total images: {len(image_paths)}")
  184. # 在主进程中初始化paddle,防止子进程CUDA初始化冲突
  185. try:
  186. import paddle
  187. # 只在主进程中设置一个默认设备
  188. paddle.set_device("cpu") # 主进程使用CPU
  189. except Exception as e:
  190. print(f"Warning: Failed to initialize paddle in main process: {e}")
  191. # 使用Manager创建队列
  192. with Manager() as manager:
  193. task_queue = manager.Queue()
  194. result_queue = manager.Queue()
  195. # 将任务放入队列
  196. for img_path in image_paths:
  197. task_queue.put(str(img_path))
  198. print(f"Added {len(image_paths)} tasks to queue")
  199. # 创建并启动工作进程
  200. processes = []
  201. worker_id = 0
  202. for device_id in device_ids:
  203. for instance_idx in range(instances_per_device):
  204. device = constr_device(device_type, [device_id])
  205. p = Process(
  206. target=worker,
  207. args=(
  208. pipeline_name,
  209. device,
  210. task_queue,
  211. result_queue,
  212. batch_size,
  213. str(output_path),
  214. worker_id,
  215. ),
  216. name=f"Worker-{worker_id}-{device}"
  217. )
  218. p.start()
  219. processes.append(p)
  220. worker_id += 1
  221. print(f"Started {len(processes)} worker processes")
  222. # 收集结果
  223. all_results = []
  224. completed_images = 0
  225. total_images = len(image_paths)
  226. with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
  227. # 等待所有结果
  228. active_workers = len(processes)
  229. while completed_images < total_images and active_workers > 0:
  230. try:
  231. # 设置较短的超时时间,定期检查进程状态
  232. batch_results = result_queue.get(timeout=5.0)
  233. all_results.extend(batch_results)
  234. batch_size_actual = len(batch_results)
  235. completed_images += batch_size_actual
  236. pbar.update(batch_size_actual)
  237. # 更新进度条信息
  238. success_count = sum(1 for r in batch_results if r.get('success', False))
  239. total_success = sum(1 for r in all_results if r.get('success', False))
  240. # 按设备统计
  241. device_stats = {}
  242. for r in all_results:
  243. device = r.get('device', 'unknown')
  244. if device not in device_stats:
  245. device_stats[device] = {'success': 0, 'total': 0}
  246. device_stats[device]['total'] += 1
  247. if r.get('success', False):
  248. device_stats[device]['success'] += 1
  249. device_info = ', '.join([f"{k}:{v['success']}/{v['total']}"
  250. for k, v in device_stats.items()])
  251. pbar.set_postfix({
  252. 'batch_success': f"{success_count}/{batch_size_actual}",
  253. 'total_success': f"{total_success}/{completed_images}",
  254. 'devices': device_info
  255. })
  256. except Exception as e:
  257. # 检查是否还有活跃的进程
  258. active_workers = sum(1 for p in processes if p.is_alive())
  259. if active_workers == 0:
  260. print("All workers have finished")
  261. break
  262. # 超时或其他错误,继续等待
  263. continue
  264. # 等待所有进程结束
  265. print("Waiting for all processes to finish...")
  266. for p in processes:
  267. p.join(timeout=10.0)
  268. if p.is_alive():
  269. print(f"Force terminating process: {p.name}")
  270. p.terminate()
  271. p.join(timeout=5.0)
  272. return all_results
  273. def main():
  274. """主函数"""
  275. parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
  276. # 必需参数
  277. parser.add_argument("--input_dir", type=str, required=True,
  278. help="Input directory containing images")
  279. parser.add_argument("--output_dir", type=str, default="./output",
  280. help="Output directory")
  281. # Pipeline配置
  282. parser.add_argument("--pipeline", type=str, default="PP-StructureV3",
  283. help="Pipeline name or config path")
  284. parser.add_argument("--device", type=str, default="gpu:0,1",
  285. help="Devices for parallel inference (e.g., 'gpu:0,1,2,3')")
  286. # 并行配置
  287. parser.add_argument("--instances_per_device", type=int, default=1,
  288. help="Number of pipeline instances per device")
  289. parser.add_argument("--batch_size", type=int, default=1,
  290. help="Inference batch size for each pipeline instance")
  291. # 输入文件配置
  292. parser.add_argument("--input_glob_pattern", type=str, default="*",
  293. help="Pattern to find input files")
  294. # 测试模式
  295. parser.add_argument("--test_mode", action="store_true",
  296. help="Test mode: only process first 20 images")
  297. args = parser.parse_args()
  298. # 验证输入目录
  299. input_dir = Path(args.input_dir)
  300. if not input_dir.exists():
  301. print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
  302. return 2
  303. if not input_dir.is_dir():
  304. print(f"{input_dir} is not a directory", file=sys.stderr)
  305. return 2
  306. # 验证输出目录
  307. output_dir = Path(args.output_dir)
  308. if output_dir.exists() and not output_dir.is_dir():
  309. print(f"{output_dir} is not a directory", file=sys.stderr)
  310. return 2
  311. print("="*70)
  312. print("PaddleX PP-StructureV3 Multi-GPU Parallel Processing")
  313. print("="*70)
  314. print(f"Input directory: {input_dir}")
  315. print(f"Output directory: {output_dir}")
  316. print(f"Pipeline: {args.pipeline}")
  317. print(f"Device: {args.device}")
  318. print(f"Instances per device: {args.instances_per_device}")
  319. print(f"Batch size: {args.batch_size}")
  320. print(f"Input pattern: {args.input_glob_pattern}")
  321. # 查找图像文件
  322. image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.pdf']
  323. image_files = []
  324. for ext in image_extensions:
  325. pattern = args.input_glob_pattern if args.input_glob_pattern != "*" else ext
  326. image_files.extend(input_dir.glob(pattern))
  327. # 如果没有找到文件,尝试使用用户指定的模式
  328. if not image_files and args.input_glob_pattern != "*":
  329. image_files = list(input_dir.glob(args.input_glob_pattern))
  330. if not image_files:
  331. print(f"No image files found in {input_dir} with pattern {args.input_glob_pattern}")
  332. return 1
  333. # 转换为字符串路径
  334. image_paths = [str(f) for f in image_files]
  335. print(f"Found {len(image_paths)} image files")
  336. # 测试模式
  337. if args.test_mode:
  338. image_paths = image_paths[:20]
  339. print(f"Test mode: processing only {len(image_paths)} images")
  340. # 开始处理
  341. start_time = time.time()
  342. try:
  343. results = parallel_process_with_official_approach(
  344. image_paths=image_paths,
  345. pipeline_name=args.pipeline,
  346. device_str=args.device,
  347. instances_per_device=args.instances_per_device,
  348. batch_size=args.batch_size,
  349. output_dir=args.output_dir
  350. )
  351. total_time = time.time() - start_time
  352. # 统计信息
  353. success_count = sum(1 for r in results if r.get('success', False))
  354. error_count = len(results) - success_count
  355. total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
  356. avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
  357. # 按设备统计
  358. device_stats = {}
  359. worker_stats = {}
  360. for r in results:
  361. device = r.get('device', 'unknown')
  362. worker_id = r.get('worker_id', 'unknown')
  363. # 设备统计
  364. if device not in device_stats:
  365. device_stats[device] = {'success': 0, 'total': 0, 'total_time': 0}
  366. device_stats[device]['total'] += 1
  367. if r.get('success', False):
  368. device_stats[device]['success'] += 1
  369. device_stats[device]['total_time'] += r.get('processing_time', 0)
  370. # Worker统计
  371. if worker_id not in worker_stats:
  372. worker_stats[worker_id] = {'success': 0, 'total': 0, 'device': device}
  373. worker_stats[worker_id]['total'] += 1
  374. if r.get('success', False):
  375. worker_stats[worker_id]['success'] += 1
  376. # 保存详细结果
  377. detailed_results = {
  378. "configuration": {
  379. "pipeline": args.pipeline,
  380. "device": args.device,
  381. "instances_per_device": args.instances_per_device,
  382. "batch_size": args.batch_size,
  383. "input_glob_pattern": args.input_glob_pattern,
  384. "test_mode": args.test_mode
  385. },
  386. "statistics": {
  387. "total_files": len(image_paths),
  388. "success_count": success_count,
  389. "error_count": error_count,
  390. "success_rate": success_count / len(image_paths) if image_paths else 0,
  391. "total_time": total_time,
  392. "avg_processing_time": avg_processing_time,
  393. "throughput": len(image_paths) / total_time if total_time > 0 else 0,
  394. "device_stats": device_stats,
  395. "worker_stats": worker_stats
  396. },
  397. "results": results
  398. }
  399. # 保存结果文件
  400. result_file = output_dir / "processing_results.json"
  401. with open(result_file, 'w', encoding='utf-8') as f:
  402. json.dump(detailed_results, f, ensure_ascii=False, indent=2)
  403. # 打印统计信息
  404. print("\n" + "="*70)
  405. print("Processing completed!")
  406. print("="*70)
  407. print(f"Total files: {len(image_paths)}")
  408. print(f"Successfully processed: {success_count}")
  409. print(f"Failed: {error_count}")
  410. print(f"Success rate: {success_count / len(image_paths) * 100:.2f}%")
  411. print(f"Total time: {total_time:.2f} seconds")
  412. print(f"Average processing time: {avg_processing_time:.2f} seconds/image")
  413. print(f"Throughput: {len(image_paths) / total_time:.2f} images/second")
  414. # 设备统计
  415. print(f"\nDevice Statistics:")
  416. for device, stats in device_stats.items():
  417. if stats['total'] > 0:
  418. success_rate = stats['success'] / stats['total'] * 100
  419. avg_time = stats['total_time'] / stats['success'] if stats['success'] > 0 else 0
  420. print(f" {device}: {stats['success']}/{stats['total']} "
  421. f"({success_rate:.1f}%), avg {avg_time:.2f}s/image")
  422. # Worker统计
  423. print(f"\nWorker Statistics:")
  424. for worker_id, stats in worker_stats.items():
  425. if stats['total'] > 0:
  426. success_rate = stats['success'] / stats['total'] * 100
  427. print(f" Worker {worker_id} ({stats['device']}): {stats['success']}/{stats['total']} "
  428. f"({success_rate:.1f}%)")
  429. print(f"\nDetailed results saved to: {result_file}")
  430. print("All done!")
  431. return 0
  432. except Exception as e:
  433. print(f"Processing failed: {e}", file=sys.stderr)
  434. traceback.print_exc()
  435. return 1
  436. if __name__ == "__main__":
  437. print(f"🚀 启动OCR程序...")
  438. print(f"CUDA 版本: {paddle.device.cuda.get_device_name()}")
  439. print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
  440. available_gpus = detect_available_gpus()
  441. monitor_gpu_memory(available_gpus)
  442. if len(sys.argv) == 1:
  443. # 如果没有命令行参数,使用默认配置运行
  444. print("No command line arguments provided. Running with default configuration...")
  445. # 默认配置
  446. default_config = {
  447. "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
  448. "output_dir": "./OmniDocBench_Results_Official",
  449. "pipeline": "PP-StructureV3",
  450. "device": "gpu:0",
  451. "instances_per_device": 1,
  452. "batch_size": 4,
  453. # "test_mode": False
  454. }
  455. # 构造参数
  456. sys.argv = [sys.argv[0]]
  457. for key, value in default_config.items():
  458. sys.argv.extend([f"--{key}", str(value)])
  459. # 测试模式
  460. sys.argv.append("--test_mode")
  461. sys.exit(main())