mp_infer_orignal.py

import argparse
import sys
from multiprocessing import Manager, Process
from pathlib import Path
from queue import Empty

from paddlex import create_pipeline
from paddlex.utils.device import constr_device, parse_device

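
# Worker: one pipeline instance pinned to a single device. It drains the shared
# task queue, grouping inputs into batches of `batch_size` before prediction.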
def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)

    should_end = False
    batch = []

    while not should_end:
        try:
            input_path = task_queue.get_nowait()
        except Empty:
            # Queue drained: flush whatever is left in the batch, then exit.
            should_end = True
        else:
            batch.append(input_path)

        if batch and (len(batch) == batch_size or should_end):
            try:
                for result in pipeline.predict(batch):
                    input_path = Path(result["input_path"])
                    # Multi-page inputs (e.g. PDFs) yield one result per page.
                    if result.get("page_index") is not None:
                        output_path = f"{input_path.stem}_{result['page_index']}.json"
                    else:
                        output_path = f"{input_path.stem}.json"
                    output_path = str(Path(output_dir, output_path))
                    result.save_to_json(output_path)
                    print(f"Processed {repr(str(input_path))}")
            except Exception as e:
                print(
                    f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
                )
            batch.clear()

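
# Parse the CLI arguments, validate the input/output directories, then fan out
# one worker process per (device, instance) pair.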
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline", type=str, required=True, help="Pipeline name or config path."
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Specifies the devices for performing parallel inference.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output", help="Output directory."
    )
    parser.add_argument(
        "--instances_per_device",
        type=int,
        default=1,
        help="Number of pipeline instances per device.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Inference batch size for each pipeline instance.",
    )
    parser.add_argument(
        "--input_glob_pattern",
        type=str,
        default="*",
        help="Pattern to find the input files.",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir).resolve()
    print(f"Input directory: {input_dir}")
    if not input_dir.exists():
        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
        return 2
    if not input_dir.is_dir():
        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
        return 2

    output_dir = Path(args.output_dir).resolve()
    print(f"Output directory: {output_dir}")
    if output_dir.exists() and not output_dir.is_dir():
        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
        return 2
    output_dir.mkdir(parents=True, exist_ok=True)

    device_type, device_ids = parse_device(args.device)
    if device_ids is None or len(device_ids) == 1:
        print(
            "Please specify at least two devices for performing parallel inference.",
            file=sys.stderr,
        )
        return 2

    if args.batch_size <= 0:
        print("Batch size must be greater than 0.", file=sys.stderr)
        return 2
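
    # A Manager-backed queue is shared across processes; workers pull input
    # paths from it until it is empty.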
    with Manager() as manager:
        task_queue = manager.Queue()
        for img_path in input_dir.glob(args.input_glob_pattern):
            task_queue.put(str(img_path))

        processes = []
        for device_id in device_ids:
            for _ in range(args.instances_per_device):
                # Each worker gets a single-device string such as "gpu:0".
                device = constr_device(device_type, [device_id])
                p = Process(
                    target=worker,
                    args=(
                        args.pipeline,
                        device,
                        task_queue,
                        args.batch_size,
                        str(output_dir),
                    ),
                )
                p.start()
                processes.append(p)

        for p in processes:
            p.join()

    print("All done")
    return 0

if __name__ == "__main__":
    sys.exit(main())
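
# Example invocation (a sketch; the pipeline name, device string, and paths
# below are placeholders, adjust them to your setup):
#
#   python mp_infer_orignal.py \
#       --pipeline image_classification \
#       --input_dir images/ \
#       --device 'gpu:0,1,2,3' \
#       --output_dir output \
#       --instances_per_device 1 \
#       --batch_size 1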