| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- import argparse
- import sys
- import os
- import time
- import traceback
- from multiprocessing import Manager, Process
- from pathlib import Path
- from queue import Empty
- from paddlex import create_pipeline
- from paddlex.utils.device import constr_device, parse_device
def _save_result(result, output_dir):
    """Persist one prediction result as JSON + Markdown files.

    Args:
        result: A single pipeline prediction result (mapping-like, with
            ``save_to_json``/``save_to_markdown`` methods).
        output_dir: Directory path (str) where the files are written.

    Returns:
        True if both files were saved, False if saving failed.
    """
    try:
        input_path = Path(result["input_path"])
        # Multi-page inputs carry a page_index; suffix the stem with it so
        # pages of the same document do not overwrite each other
        # (mirrors PP-StructureV3's own naming scheme).
        if result.get("page_index") is not None:
            output_filename = f"{input_path.stem}_{result['page_index']}"
        else:
            output_filename = f"{input_path.stem}"
        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
        result.save_to_json(json_output_path)
        result.save_to_markdown(md_output_path)
        print(
            f"Processed {repr(str(input_path))} -> {json_output_path}, {md_output_path}"
        )
        return True
    except Exception as e:
        print(
            f"Error saving result for {result.get('input_path', 'unknown')}: {e}",
            file=sys.stderr,
        )
        traceback.print_exc()
        return False


def _process_batch(pipeline, batch, device, output_dir):
    """Run one prediction batch and save every result.

    A failure inside the batch is logged and swallowed so the worker can
    continue with the next batch (best-effort semantics).

    Returns:
        Number of results successfully saved.
    """
    saved = 0
    try:
        start_time = time.time()
        # PP-StructureV3 options: enable every recognizer except document
        # unwarping.
        results = pipeline.predict(
            batch,
            use_doc_orientation_classify=True,
            use_doc_unwarping=False,
            use_seal_recognition=True,
            use_chart_recognition=True,
            use_table_recognition=True,
            use_formula_recognition=True,
        )
        batch_processing_time = time.time() - start_time
        for result in results:
            if _save_result(result, output_dir):
                saved += 1
        print(
            f"Batch processed: {len(batch)} files in {batch_processing_time:.2f}s on {device}"
        )
    except Exception as e:
        print(f"Error processing batch {batch} on {repr(device)}: {e}", file=sys.stderr)
        traceback.print_exc()
    return saved


def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
    """Worker process: drain input paths from the shared queue and run the
    pipeline on them in batches, saving JSON/Markdown output per result.

    Args:
        pipeline_name_or_config_path: Pipeline name or config path passed
            to ``create_pipeline``.
        device: Device string (e.g. ``"gpu:0"``) this worker is pinned to.
        task_queue: Shared queue of input file paths; ``None`` is the
            end-of-work sentinel.
        batch_size: Number of inputs per ``pipeline.predict`` call.
        output_dir: Directory (str) where result files are written.
    """
    import paddle

    # Sanity-check the Paddle installation on this worker's device before
    # building the (expensive) pipeline instance.
    paddle.utils.run_check()

    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)

    should_end = False
    batch = []
    processed_count = 0

    while not should_end:
        try:
            input_path = task_queue.get_nowait()
        except Empty:
            # All tasks were enqueued before workers start, so an empty
            # queue means there is no more work for this worker.
            should_end = True
        except Exception as e:
            print(f"Unexpected error while getting task: {e}", file=sys.stderr)
            traceback.print_exc()
            should_end = True
        else:
            if input_path is None:
                # Explicit end-of-work sentinel from the parent process.
                should_end = True
            else:
                batch.append(input_path)

        # Flush when the batch is full, or flush the final partial batch
        # on shutdown.
        if batch and (len(batch) == batch_size or should_end):
            processed_count += _process_batch(pipeline, batch, device, output_dir)
            batch.clear()

    # Per-worker summary so the parent log shows how work was distributed.
    print(f"Worker on {device} finished: {processed_count} file(s) processed.")
def main():
    """CLI entry point: shard image files across per-device worker processes.

    Parses arguments, validates the input/output directories and device
    list, fills a shared task queue with image paths, then launches
    ``instances_per_device`` worker processes per device and waits for them.

    Returns:
        Process exit code: 0 on success, 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline", type=str, required=True, help="Pipeline name or config path."
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Specifies the devices for performing parallel inference.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output", help="Output directory."
    )
    parser.add_argument(
        "--instances_per_device",
        type=int,
        default=1,
        help="Number of pipeline instances per device.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Inference batch size for each pipeline instance.",
    )
    parser.add_argument(
        "--input_glob_pattern",
        type=str,
        default="*",
        help="Pattern to find the input files.",
    )
    parser.add_argument(
        "--test_mode",
        action="store_true",
        help="Test mode (process only 20 images)"
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir).resolve()
    print(f"Input directory: {input_dir}")
    if not input_dir.exists():
        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
        return 2
    if not input_dir.is_dir():
        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
        return 2

    output_dir = Path(args.output_dir).resolve()
    print(f"Output directory: {output_dir}")
    if output_dir.exists() and not output_dir.is_dir():
        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
        return 2
    output_dir.mkdir(parents=True, exist_ok=True)

    device_type, device_ids = parse_device(args.device)
    # `< 2` also rejects an empty id list, not just a single id.
    if device_ids is None or len(device_ids) < 2:
        print(
            "Please specify at least two devices for performing parallel inference.",
            file=sys.stderr,
        )
        return 2

    if args.batch_size <= 0:
        print("Batch size must be greater than 0.", file=sys.stderr)
        return 2

    # Discover image files with a single glob honoring --input_glob_pattern,
    # then filter by suffix case-insensitively. Globbing each case variant
    # separately would double-count files on case-insensitive filesystems
    # (macOS/Windows); sorting makes the work order deterministic.
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
    image_files = sorted(
        p
        for p in input_dir.glob(args.input_glob_pattern)
        if p.suffix.lower() in image_extensions
    )
    print(f"Found {len(image_files)} image files")

    if args.test_mode:
        image_files = image_files[:20]
        print(f"Test mode: processing only {len(image_files)} images")

    with Manager() as manager:
        task_queue = manager.Queue()
        # Enqueue every task before any worker starts, so an Empty queue
        # unambiguously means "no more work".
        for img_path in image_files:
            task_queue.put(str(img_path))

        processes = []
        for device_id in device_ids:
            for _ in range(args.instances_per_device):
                device = constr_device(device_type, [device_id])
                p = Process(
                    target=worker,
                    args=(
                        args.pipeline,
                        device,
                        task_queue,
                        args.batch_size,
                        str(output_dir),
                    ),
                )
                p.start()
                processes.append(p)

        # One sentinel per worker so every process sees an end-of-work signal.
        for _ in range(len(device_ids) * args.instances_per_device):
            task_queue.put(None)

        for p in processes:
            p.join()

    print("All done")
    return 0
# Run only when executed as a script; propagate main()'s exit code to the shell.
if __name__ == "__main__":
    sys.exit(main())
|