# mp_infer.py — multiprocess parallel inference with a PaddleX pipeline.
  1. import argparse
  2. import sys
  3. import os
  4. import time
  5. import traceback
  6. from multiprocessing import Manager, Process
  7. from pathlib import Path
  8. from queue import Empty
  9. from paddlex import create_pipeline
  10. from paddlex.utils.device import constr_device, parse_device
  11. def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
  12. import paddle
  13. paddle.utils.run_check()
  14. # 限制GPU内存使用,减少CUDA冲突
  15. # os.environ["FLAGS_fraction_of_gpu_memory_to_use"] = "0.6"
  16. # 使用确定性算法
  17. # os.environ["FLAGS_cudnn_deterministic"] = "1"
  18. # 立即释放内存
  19. # os.environ["FLAGS_eager_delete_tensor_gb"] = "0.0"
  20. pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
  21. should_end = False
  22. batch = []
  23. processed_count = 0
  24. while not should_end:
  25. try:
  26. input_path = task_queue.get_nowait()
  27. except Empty:
  28. should_end = True
  29. except Exception as e:
  30. # 处理其他可能的异常
  31. print(f"Unexpected error while getting task: {e}", file=sys.stderr)
  32. traceback.print_exc()
  33. should_end = True
  34. else:
  35. if input_path is None:
  36. should_end = True
  37. else:
  38. batch.append(input_path)
  39. if batch and (len(batch) == batch_size or should_end):
  40. try:
  41. start_time = time.time()
  42. # 使用pipeline预测,添加PP-StructureV3的参数
  43. results = pipeline.predict(
  44. batch,
  45. use_doc_orientation_classify=True,
  46. use_doc_unwarping=False,
  47. use_seal_recognition=True,
  48. use_chart_recognition=True,
  49. use_table_recognition=True,
  50. use_formula_recognition=True,
  51. )
  52. batch_processing_time = time.time() - start_time
  53. for result in results:
  54. try:
  55. input_path = Path(result["input_path"])
  56. # 保存结果 - 按照ppstructurev3的方式处理文件名
  57. if result.get("page_index") is not None:
  58. output_filename = f"{input_path.stem}_{result['page_index']}"
  59. else:
  60. output_filename = f"{input_path.stem}"
  61. # 保存JSON和Markdown文件
  62. json_output_path = str(Path(output_dir, f"{output_filename}.json"))
  63. md_output_path = str(Path(output_dir, f"{output_filename}.md"))
  64. result.save_to_json(json_output_path)
  65. result.save_to_markdown(md_output_path)
  66. processed_count += 1
  67. print(
  68. f"Processed {repr(str(input_path))} -> {json_output_path}, {md_output_path}"
  69. )
  70. except Exception as e:
  71. print(
  72. f"Error saving result for {result.get('input_path', 'unknown')}: {e}",
  73. file=sys.stderr,
  74. )
  75. traceback.print_exc()
  76. print(
  77. f"Batch processed: {len(batch)} files in {batch_processing_time:.2f}s on {device}"
  78. )
  79. except Exception as e:
  80. print(f"Error processing batch {batch} on {repr(device)}: {e}", file=sys.stderr)
  81. traceback.print_exc()
  82. batch.clear()
  83. def main():
  84. parser = argparse.ArgumentParser()
  85. parser.add_argument(
  86. "--pipeline", type=str, required=True, help="Pipeline name or config path."
  87. )
  88. parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
  89. parser.add_argument(
  90. "--device",
  91. type=str,
  92. required=True,
  93. help="Specifies the devices for performing parallel inference.",
  94. )
  95. parser.add_argument(
  96. "--output_dir", type=str, default="output", help="Output directory."
  97. )
  98. parser.add_argument(
  99. "--instances_per_device",
  100. type=int,
  101. default=1,
  102. help="Number of pipeline instances per device.",
  103. )
  104. parser.add_argument(
  105. "--batch_size",
  106. type=int,
  107. default=1,
  108. help="Inference batch size for each pipeline instance.",
  109. )
  110. parser.add_argument(
  111. "--input_glob_pattern",
  112. type=str,
  113. default="*",
  114. help="Pattern to find the input files.",
  115. )
  116. parser.add_argument(
  117. "--test_mode",
  118. action="store_true",
  119. help="Test mode (process only 20 images)"
  120. )
  121. args = parser.parse_args()
  122. input_dir = Path(args.input_dir).resolve()
  123. print(f"Input directory: {input_dir}")
  124. if not input_dir.exists():
  125. print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
  126. return 2
  127. if not input_dir.is_dir():
  128. print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
  129. return 2
  130. output_dir = Path(args.output_dir).resolve()
  131. print(f"Output directory: {output_dir}")
  132. if output_dir.exists() and not output_dir.is_dir():
  133. print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
  134. return 2
  135. output_dir.mkdir(parents=True, exist_ok=True)
  136. device_type, device_ids = parse_device(args.device)
  137. if device_ids is None or len(device_ids) == 1:
  138. print(
  139. "Please specify at least two devices for performing parallel inference.",
  140. file=sys.stderr,
  141. )
  142. return 2
  143. if args.batch_size <= 0:
  144. print("Batch size must be greater than 0.", file=sys.stderr)
  145. return 2
  146. # 查找图像文件
  147. image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
  148. image_files = []
  149. for ext in image_extensions:
  150. image_files.extend(list(input_dir.glob(f"*{ext}")))
  151. image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
  152. print(f"Found {len(image_files)} image files")
  153. if args.test_mode:
  154. image_files = image_files[:20]
  155. print(f"Test mode: processing only {len(image_files)} images")
  156. with Manager() as manager:
  157. task_queue = manager.Queue()
  158. # 将图像文件路径放入队列
  159. for img_path in image_files:
  160. task_queue.put(str(img_path))
  161. processes = []
  162. for device_id in device_ids:
  163. for _ in range(args.instances_per_device):
  164. device = constr_device(device_type, [device_id])
  165. p = Process(
  166. target=worker,
  167. args=(
  168. args.pipeline,
  169. device,
  170. task_queue,
  171. args.batch_size,
  172. str(output_dir),
  173. ),
  174. )
  175. p.start()
  176. processes.append(p)
  177. # 发送结束信号
  178. for _ in range(len(device_ids) * args.instances_per_device):
  179. task_queue.put(None)
  180. for p in processes:
  181. p.join()
  182. print("All done")
  183. return 0
# Script entry point: exit with main()'s return code (0 on success, 2 on
# invalid arguments or paths).
if __name__ == "__main__":
    sys.exit(main())