
feat(zhch): add multi-process inference script

- Add mp_infer.py, which implements multi-process inference
- Pipeline name or config path, input directory, output directory, etc. are specified via command-line arguments
- Multiple devices can be specified for parallel inference
- The number of pipeline instances per device and the inference batch size are configurable
- Input file paths support glob pattern matching
- Results are saved as JSON to the output directory
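
Example invocation (the pipeline name, paths, and device list are illustrative):
python zhch/mp_infer.py --pipeline OCR --input_dir images --device "gpu:0,1" --output_dir output --batch_size 4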

zhch158_admin 3 months ago
parent commit 05491184da
1 file changed, 142 additions, 0 deletions

+ 142 - 0
zhch/mp_infer.py

@@ -0,0 +1,142 @@
+import argparse
+import sys
+from multiprocessing import Manager, Process
+from pathlib import Path
+from queue import Empty
+
+from paddlex import create_pipeline
+from paddlex.utils.device import constr_device, parse_device
+
+
+def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
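+    # Each worker process creates its own pipeline instance on its assigned device.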
+    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
+
+    should_end = False
+    batch = []
+
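+    # Drain the shared task queue, accumulating input paths into batches.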
+    while not should_end:
+        try:
+            input_path = task_queue.get_nowait()
+        except Empty:
+            should_end = True
+        else:
+            batch.append(input_path)
+
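+        # Run inference once a full batch has accumulated or the queue is exhausted.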
+        if batch and (len(batch) == batch_size or should_end):
+            try:
+                for result in pipeline.predict(batch):
+                    input_path = Path(result["input_path"])
+                    if result.get("page_index") is not None:
+                        output_path = f"{input_path.stem}_{result['page_index']}.json"
+                    else:
+                        output_path = f"{input_path.stem}.json"
+                    output_path = str(Path(output_dir, output_path))
+                    result.save_to_json(output_path)
+                    print(f"Processed {repr(str(input_path))}")
+            except Exception as e:
+                print(
+                    f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
+                )
+            batch.clear()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pipeline", type=str, required=True, help="Pipeline name or config path."
+    )
+    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=True,
+        help="Specifies the devices for performing parallel inference.",
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default="output", help="Output directory."
+    )
+    parser.add_argument(
+        "--instances_per_device",
+        type=int,
+        default=1,
+        help="Number of pipeline instances per device.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Inference batch size for each pipeline instance.",
+    )
+    parser.add_argument(
+        "--input_glob_pattern",
+        type=str,
+        default="*",
+        help="Pattern to find the input files.",
+    )
+    args = parser.parse_args()
+
+    input_dir = Path(args.input_dir).resolve()
+    print(f"Input directory: {input_dir}")
+    if not input_dir.exists():
+        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
+        return 2
+    if not input_dir.is_dir():
+        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
+        return 2
+
+    output_dir = Path(args.output_dir).resolve()
+    print(f"Output directory: {output_dir}")
+    if output_dir.exists() and not output_dir.is_dir():
+        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
+        return 2
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    device_type, device_ids = parse_device(args.device)
+    if device_ids is None or len(device_ids) < 2:
+        print(
+            "Please specify at least two devices for performing parallel inference.",
+            file=sys.stderr,
+        )
+        return 2
+
+    if args.batch_size <= 0:
+        print("Batch size must be greater than 0.", file=sys.stderr)
+        return 2
+
+    with Manager() as manager:
+        task_queue = manager.Queue()
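+        # Enqueue every input file that matches the glob pattern.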
+        for img_path in input_dir.glob(args.input_glob_pattern):
+            task_queue.put(str(img_path))
+
+        processes = []
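+        # Launch instances_per_device workers per device, each pinned to a single device.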
+        for device_id in device_ids:
+            for _ in range(args.instances_per_device):
+                device = constr_device(device_type, [device_id])
+                p = Process(
+                    target=worker,
+                    args=(
+                        args.pipeline,
+                        device,
+                        task_queue,
+                        args.batch_size,
+                        str(output_dir),
+                    ),
+                )
+                p.start()
+                processes.append(p)
+
+        for p in processes:
+            p.join()
+
+    print("All done")
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())