import argparse
import sys
import time
import traceback
from multiprocessing import Manager, Process
from pathlib import Path
from queue import Empty

from paddlex import create_pipeline
from paddlex.utils.device import constr_device, parse_device


def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_dir):
    pipeline = create_pipeline(pipeline_name_or_config_path, device=device)

    should_end = False
    batch = []
    processed_count = 0

    while not should_end:
        try:
            input_path = task_queue.get_nowait()
        except Empty:
            should_end = True
        except Exception as e:
            # Handle any other unexpected exception while fetching a task
            print(f"Unexpected error while getting task: {e}", file=sys.stderr)
            traceback.print_exc()
            should_end = True
        else:
            if input_path is None:
                # A None sentinel tells this worker to shut down
                should_end = True
            else:
                batch.append(input_path)

        if batch and (len(batch) == batch_size or should_end):
            try:
                start_time = time.time()
                # Run prediction with the PP-StructureV3 options enabled
                results = pipeline.predict(
                    batch,
                    use_doc_orientation_classify=True,
                    use_doc_unwarping=False,
                    use_seal_recognition=True,
                    use_chart_recognition=True,
                    use_table_recognition=True,
                    use_formula_recognition=True,
                )
                batch_processing_time = time.time() - start_time

                for result in results:
                    try:
                        input_path = Path(result["input_path"])

                        # Name outputs the way PP-StructureV3 does: append the
                        # page index for multi-page inputs
                        if result.get("page_index") is not None:
                            output_filename = f"{input_path.stem}_{result['page_index']}"
                        else:
                            output_filename = f"{input_path.stem}"

                        # Save the JSON and Markdown files
                        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
                        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
                        result.save_to_json(json_output_path)
                        result.save_to_markdown(md_output_path)

                        processed_count += 1
                        print(
                            f"Processed {repr(str(input_path))} -> {json_output_path}, {md_output_path}"
                        )
                    except Exception as e:
                        print(
                            f"Error saving result for {result.get('input_path', 'unknown')}: {e}",
                            file=sys.stderr,
                        )
                        traceback.print_exc()

                print(
                    f"Batch processed: {len(batch)} files in {batch_processing_time:.2f}s on {device}"
                )
            except Exception as e:
                print(
                    f"Error processing batch {batch} on {repr(device)}: {e}",
                    file=sys.stderr,
                )
                traceback.print_exc()

            batch.clear()

    print(f"Worker on {device} finished; processed {processed_count} files")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline", type=str, required=True, help="Pipeline name or config path."
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Specifies the devices for performing parallel inference.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output", help="Output directory."
    )
    parser.add_argument(
        "--instances_per_device",
        type=int,
        default=1,
        help="Number of pipeline instances per device.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Inference batch size for each pipeline instance.",
    )
    parser.add_argument(
        "--input_glob_pattern",
        type=str,
        default="*",
        help="Pattern to find the input files.",
    )
    parser.add_argument(
        "--test_mode", action="store_true", help="Test mode (process only 20 images)"
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir).resolve()
    print(f"Input directory: {input_dir}")
    if not input_dir.exists():
        print(f"The input directory does not exist: {input_dir}", file=sys.stderr)
        return 2
    if not input_dir.is_dir():
        print(f"{repr(str(input_dir))} is not a directory.", file=sys.stderr)
        return 2

    output_dir = Path(args.output_dir).resolve()
    print(f"Output directory: {output_dir}")
    if output_dir.exists() and not output_dir.is_dir():
        print(f"{repr(str(output_dir))} is not a directory.", file=sys.stderr)
        return 2
    output_dir.mkdir(parents=True, exist_ok=True)

    device_type, device_ids = parse_device(args.device)
    if device_ids is None or len(device_ids) == 1:
        print(
            "Please specify at least two devices for performing parallel inference.",
            file=sys.stderr,
        )
        return 2

    if args.batch_size <= 0:
        print("Batch size must be greater than 0.", file=sys.stderr)
        return 2

    # Collect image files that match the glob pattern (previously the
    # --input_glob_pattern flag was defined but never used); filtering on the
    # lowered suffix covers upper-case extensions without double-counting
    # files on case-insensitive filesystems
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
    image_files = [
        path
        for path in sorted(input_dir.glob(args.input_glob_pattern))
        if path.suffix.lower() in image_extensions
    ]
    print(f"Found {len(image_files)} image files")

    if args.test_mode:
        image_files = image_files[:20]
        print(f"Test mode: processing only {len(image_files)} images")

    with Manager() as manager:
        task_queue = manager.Queue()
        # Put the image file paths into the shared task queue
        for img_path in image_files:
            task_queue.put(str(img_path))

        processes = []
        for device_id in device_ids:
            for _ in range(args.instances_per_device):
                device = constr_device(device_type, [device_id])
                p = Process(
                    target=worker,
                    args=(
                        args.pipeline,
                        device,
                        task_queue,
                        args.batch_size,
                        str(output_dir),
                    ),
                )
                p.start()
                processes.append(p)

        # Send one None sentinel per worker so every worker shuts down
        for _ in range(len(device_ids) * args.instances_per_device):
            task_queue.put(None)

        for p in processes:
            p.join()

    print("All done")
    return 0


if __name__ == "__main__":
    sys.exit(main())
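# Example invocation (a sketch: the script filename is hypothetical; the
# "gpu:0,1" device string follows the PaddleX convention accepted by
# parse_device, and "PP-StructureV3" matches the pipeline options enabled in
# worker() above):
#
#   python parallel_ppstructurev3.py \
#       --pipeline PP-StructureV3 \
#       --input_dir ./images \
#       --output_dir ./output \
#       --device "gpu:0,1" \
#       --instances_per_device 1 \
#       --batch_size 4 \
#       --input_glob_pattern "*.png"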