import argparse
import sys
from multiprocessing import Manager, Process
from pathlib import Path
from queue import Empty

from paddlex import create_pipeline
from paddlex.utils.device import constr_device, parse_device


def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
    # Each worker owns one pipeline instance bound to a single device and keeps
    # pulling input paths from the shared task queue until it is empty.
    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)

    should_end = False
    batch = []

    while not should_end:
        try:
            input_path = task_queue.get_nowait()
        except Empty:
            should_end = True
        else:
            batch.append(input_path)

        # Run inference once a full batch has accumulated or the queue is exhausted.
        if batch and (len(batch) == batch_size or should_end):
            try:
                for result in pipeline.predict(batch):
                    input_path = Path(result["input_path"])
                    if result.get("page_index") is not None:
                        output_path = f"{input_path.stem}_{result['page_index']}.json"
                    else:
                        output_path = f"{input_path.stem}.json"
                    output_path = str(Path(output_dir, output_path))
                    result.save_to_json(output_path)
                    print(f"Processed {repr(str(input_path))}")
            except Exception as e:
                print(
                    f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
                )
            batch.clear()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline", type=str, required=True, help="Pipeline name or config path."
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Specifies the devices for performing parallel inference.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output", help="Output directory."
    )
    parser.add_argument(
        "--instances_per_device",
        type=int,
        default=1,
        help="Number of pipeline instances per device.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Inference batch size for each pipeline instance.",
    )
    parser.add_argument(
        "--input_glob_pattern",
        type=str,
        default="*",
        help="Pattern to find the input files.",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir).resolve()
    print(f"Input directory: {input_dir}")
    if not input_dir.exists():
        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
        return 2
    if not input_dir.is_dir():
        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
        return 2

    output_dir = Path(args.output_dir).resolve()
    print(f"Output directory: {output_dir}")
    if output_dir.exists() and not output_dir.is_dir():
        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
        return 2
    output_dir.mkdir(parents=True, exist_ok=True)

    device_type, device_ids = parse_device(args.device)
    if device_ids is None or len(device_ids) == 1:
        print(
            "Please specify at least two devices for performing parallel inference.",
            file=sys.stderr,
        )
        return 2

    if args.batch_size <= 0:
        print("Batch size must be greater than 0.", file=sys.stderr)
        return 2

    with Manager() as manager:
        # Distribute the input files to the workers through a shared queue.
        task_queue = manager.Queue()
        for img_path in input_dir.glob(args.input_glob_pattern):
            task_queue.put(str(img_path))

        # Spawn the requested number of pipeline instances for each device.
        processes = []
        for device_id in device_ids:
            for _ in range(args.instances_per_device):
                device = constr_device(device_type, [device_id])
                p = Process(
                    target=worker,
                    args=(
                        args.pipeline,
                        device,
                        task_queue,
                        args.batch_size,
                        str(output_dir),
                    ),
                )
                p.start()
                processes.append(p)

        for p in processes:
            p.join()

    print("All done")
    return 0


if __name__ == "__main__":
    sys.exit(main())
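As a rough usage sketch, assuming the script above is saved as `infer_mp.py` (the filename is illustrative), it could be launched with something like `python infer_mp.py --pipeline OCR --input_dir images --device gpu:0,1,2,3 --output_dir output --batch_size 4`; the pipeline name and the device string accepted by `parse_device` depend on your PaddleX installation, so adjust them to your environment. Each listed device gets `--instances_per_device` worker processes, and every worker feeds its pipeline batches of up to `--batch_size` inputs drawn from the shared queue until the queue is drained.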