@@ -0,0 +1,151 @@
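+"""Run a PaddleX pipeline on a directory of files in parallel across devices."""
+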
+import argparse
+import sys
+from multiprocessing import Manager, Process
+from pathlib import Path
+from queue import Empty
+
+from paddlex import create_pipeline
+from paddlex.utils.device import constr_device, parse_device
+
+
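+# Each worker process creates its own pipeline instance pinned to one device
+# and pulls input paths from the shared task queue until it is drained.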
+def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
+    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
+
+    should_end = False
+    batch = []
+
+    while not should_end:
+        try:
+            input_path = task_queue.get_nowait()
+        except Empty:
+            should_end = True
+        else:
+            batch.append(input_path)
+
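+        # Run inference once a full batch has accumulated, or flush the
+        # remaining partial batch after the queue has been drained.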
+        if batch and (len(batch) == batch_size or should_end):
+            try:
+                for result in pipeline.predict(batch):
+                    input_path = Path(result["input_path"])
+                    if result.get("page_index") is not None:
+                        output_path = f"{input_path.stem}_{result['page_index']}.json"
+                    else:
+                        output_path = f"{input_path.stem}.json"
+                    output_path = str(Path(output_dir, output_path))
+                    result.save_to_json(output_path)
+                    print(f"Processed {repr(str(input_path))}")
+            except Exception as e:
+                print(
+                    f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
+                )
+            batch.clear()
+
+
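+# Example invocation (script name and pipeline are illustrative):
+#   python parallel_infer.py --pipeline OCR --input_dir images --device 'gpu:0,1'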
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pipeline", type=str, required=True, help="Pipeline name or config path."
+    )
+    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=True,
+        help="Specifies the devices for performing parallel inference.",
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default="output", help="Output directory."
+    )
+    parser.add_argument(
+        "--instances_per_device",
+        type=int,
+        default=1,
+        help="Number of pipeline instances per device.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Inference batch size for each pipeline instance.",
+    )
+    parser.add_argument(
+        "--input_glob_pattern",
+        type=str,
+        default="*",
+        help="Pattern to find the input files.",
+    )
+    args = parser.parse_args()
+
+    input_dir = Path(args.input_dir).resolve()
+    print(f"Input directory: {input_dir}")
+    if not input_dir.exists():
+        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
+        return 2
+    if not input_dir.is_dir():
+        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
+        return 2
+
+    output_dir = Path(args.output_dir).resolve()
+    print(f"Output directory: {output_dir}")
+    if output_dir.exists() and not output_dir.is_dir():
+        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
+        return 2
+    output_dir.mkdir(parents=True, exist_ok=True)
+
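+    # Split a device spec such as "gpu:0,1" into a device type and ID list.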
+    device_type, device_ids = parse_device(args.device)
+    if device_ids is None or len(device_ids) == 1:
+        print(
+            "Please specify at least two devices for performing parallel inference.",
+            file=sys.stderr,
+        )
+        return 2
+
+    if args.batch_size <= 0:
+        print("Batch size must be greater than 0.", file=sys.stderr)
+        return 2
+
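+    # Fill a shared queue with every matching input path up front; worker
+    # processes drain it concurrently.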
+    with Manager() as manager:
+        task_queue = manager.Queue()
+        for img_path in input_dir.glob(args.input_glob_pattern):
+            task_queue.put(str(img_path))
+
+        processes = []
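+        # Spawn instances_per_device workers per device, each pinned to a
+        # single device ID.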
+        for device_id in device_ids:
+            for _ in range(args.instances_per_device):
+                device = constr_device(device_type, [device_id])
+                p = Process(
+                    target=worker,
+                    args=(
+                        args.pipeline,
+                        device,
+                        task_queue,
+                        args.batch_size,
+                        str(output_dir),
+                    ),
+                )
+                p.start()
+                processes.append(p)
+
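+        # Wait for every worker to exhaust the queue and exit.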
+        for p in processes:
+            p.join()
+
+    print("All done")
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())