# ppstructurev3_single_process.py

import json
import time
import os
import traceback
import argparse
import sys
import warnings
from pathlib import Path
from typing import List, Dict, Any

import cv2
import numpy as np

# Suppress specific warnings
warnings.filterwarnings("ignore", message="To copy construct from a tensor")
warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")

from paddlex import create_pipeline
from paddlex.utils.device import constr_device, parse_device
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv(override=True)


def process_images_single_process(image_paths: List[str],
                                  pipeline_name: str = "PP-StructureV3",
                                  device: str = "gpu:0",
                                  batch_size: int = 1,
                                  output_dir: str = "./output") -> List[Dict[str, Any]]:
    """
    Single-process version of the image processing function.

    Args:
        image_paths: List of image paths
        pipeline_name: Pipeline name
        device: Device string, e.g. "gpu:0" or "cpu"
        batch_size: Batch size
        output_dir: Output directory

    Returns:
        List of per-image processing results
    """
    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
    try:
        # Set environment variable to reduce warnings
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        # Initialize the pipeline
        pipeline = create_pipeline(pipeline_name, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    all_results = []
    total_images = len(image_paths)
    print(f"Processing {total_images} images with batch size {batch_size}")

    # Show progress with tqdm, including extra statistics
    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        # Process the images batch by batch
        for i in range(0, total_images, batch_size):
            batch = image_paths[i:i + batch_size]
            batch_start_time = time.time()

            try:
                # Run prediction with the pipeline
                results = pipeline.predict(
                    batch,
                    use_doc_orientation_classify=True,
                    use_doc_unwarping=False,
                    use_seal_recognition=True,
                    use_chart_recognition=True,
                    use_table_recognition=True,
                    use_formula_recognition=True,
                )
                batch_processing_time = time.time() - batch_start_time

                batch_results = []
                # Handle each result in the batch
                for result in results:
                    try:
                        input_path = Path(result["input_path"])

                        # Build the output file name
                        if result.get("page_index") is not None:
                            output_filename = f"{input_path.stem}_{result['page_index']}"
                        else:
                            output_filename = f"{input_path.stem}"

                        # Save JSON and Markdown files
                        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
                        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
                        result.save_to_json(json_output_path)
                        result.save_to_markdown(md_output_path)

                        # Record the processing result
                        batch_results.append({
                            "image_path": input_path.name,
                            "processing_time": batch_processing_time / len(batch),  # average time per image
                            "success": True,
                            "device": device,
                            "output_json": json_output_path,
                            "output_md": md_output_path
                        })
                    except Exception as e:
                        print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
                        traceback.print_exc()
                        batch_results.append({
                            "image_path": Path(result["input_path"]).name,
                            "processing_time": 0,
                            "success": False,
                            "device": device,
                            "error": str(e)
                        })

                all_results.extend(batch_results)

                # Update the progress bar
                success_count = sum(1 for r in batch_results if r.get('success', False))
                total_success = sum(1 for r in all_results if r.get('success', False))
                avg_time = batch_processing_time / len(batch)

                pbar.update(len(batch))
                pbar.set_postfix({
                    'batch_time': f"{batch_processing_time:.2f}s",
                    'avg_time': f"{avg_time:.2f}s/img",
                    'success': f"{total_success}/{len(all_results)}",
                    'rate': f"{total_success/len(all_results)*100:.1f}%"
                })

            except Exception as e:
                print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
                traceback.print_exc()

                # Add an error result for every image in the batch
                error_results = []
                for img_path in batch:
                    error_results.append({
                        "image_path": Path(img_path).name,
                        "processing_time": 0,
                        "success": False,
                        "device": device,
                        "error": str(e)
                    })
                all_results.extend(error_results)
                pbar.update(len(batch))

    return all_results
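
# Example (hypothetical): calling process_images_single_process() directly from
# another script instead of going through the CLI below. The image paths,
# device, and output directory are illustrative only.
#
#   results = process_images_single_process(
#       ["./samples/page_001.png", "./samples/page_002.png"],
#       pipeline_name="PP-StructureV3",
#       device="cpu",
#       batch_size=2,
#       output_dir="./demo_output",
#   )
#   print(sum(1 for r in results if r["success"]), "images processed successfully")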


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Single Process Processing")

    # Argument definitions
    parser.add_argument("--input_dir", type=str, default="../../OmniDocBench/OpenDataLab___OmniDocBench/images", help="Input directory")
    parser.add_argument("--input_file_list", type=str, default=None, help="Input file list (one file per line)")
    parser.add_argument("--output_dir", type=str, default="./OmniDocBench_Results_Single", help="Output directory")
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
    parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
    parser.add_argument("--verbose", action="store_true", help="Show detailed GPU information")

    args = parser.parse_args()

    try:
        # Collect the list of image files
        if args.input_file_list:
            # Read paths from a file list
            print(f"Reading file list from: {args.input_file_list}")
            with open(args.input_file_list, 'r', encoding='utf-8') as f:
                image_files = [line.strip() for line in f if line.strip()]

            # Keep only files that exist
            valid_files = []
            for file_path in image_files:
                if Path(file_path).exists():
                    valid_files.append(file_path)
                else:
                    print(f"Warning: File not found: {file_path}")
            image_files = valid_files
        else:
            # Read from a directory (original logic)
            input_dir = Path(args.input_dir).resolve()
            print(f"Input dir: {input_dir}")

            if not input_dir.exists():
                print(f"Input directory does not exist: {input_dir}")
                return 1

            # Find image files
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
            image_files = []
            for ext in image_extensions:
                image_files.extend(list(input_dir.glob(f"*{ext}")))
                image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))

            if not image_files:
                print(f"No image files found in {input_dir}")
                return 1

            # Deduplicate and sort
            image_files = sorted(list(set(str(f) for f in image_files)))

        output_dir = Path(args.output_dir).resolve()
        print(f"Output dir: {output_dir}")
        print(f"Found {len(image_files)} image files")

        if args.test_mode:
            image_files = image_files[:20]
            print(f"Test mode: processing only {len(image_files)} images")

        # Validate the requested device
        if args.device.startswith('gpu'):
            try:
                import paddle
                if not paddle.device.is_compiled_with_cuda():
                    print("GPU requested but CUDA not available, falling back to CPU")
                    args.device = "cpu"
                else:
                    gpu_count = paddle.device.cuda.device_count()
                    device_id = int(args.device.split(':')[1]) if ':' in args.device else 0
                    if device_id >= gpu_count:
                        print(f"GPU {device_id} not available (only {gpu_count} GPUs), falling back to GPU 0")
                        args.device = "gpu:0"

                    # Show GPU information
                    if args.verbose:
                        for i in range(gpu_count):
                            props = paddle.device.cuda.get_device_properties(i)
                            print(f"GPU {i}: {props.name} - {props.total_memory // 1024**3}GB")
            except Exception as e:
                print(f"Error checking GPU availability: {e}, falling back to CPU")
                args.device = "cpu"

        print(f"Using device: {args.device}")
        print(f"Batch size: {args.batch_size}")

        # Start processing
        start_time = time.time()
        results = process_images_single_process(
            image_files,
            args.pipeline,
            args.device,
            args.batch_size,
            str(output_dir)
        )
        total_time = time.time() - start_time

        # Summarize results
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"   Total files: {len(image_files)}")
        print(f"   Successful: {success_count}")
        print(f"   Failed: {error_count}")
        if len(image_files) > 0:
            print(f"   Success rate: {success_count / len(image_files) * 100:.2f}%")
        print("⏱️ Performance:")
        print(f"   Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f"   Throughput: {len(image_files) / total_time:.2f} images/second")
            print(f"   Avg time per image: {total_time / len(image_files):.2f} seconds")

        # Collect summary statistics
        stats = {
            "total_files": len(image_files),
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(image_files) / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
            "batch_size": args.batch_size,
            "device": args.device,
            "pipeline": args.pipeline,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results
        output_file = os.path.join(output_dir, f"OmniDocBench_Single_batch{args.batch_size}.json")
        final_results = {
            "stats": stats,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        print(f"💾 Results saved to: {output_file}")

        return 0

    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
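
# Example (hypothetical): loading the summary JSON that main() writes after a
# run, assuming the default output directory and the default batch size of 4.
# Adjust the path to match your --output_dir and --batch_size.
#
#   import json
#   with open("./OmniDocBench_Results_Single/OmniDocBench_Single_batch4.json", encoding="utf-8") as f:
#       summary = json.load(f)
#   print(summary["stats"]["success_rate"], summary["stats"]["throughput"])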


if __name__ == "__main__":
    print("🚀 Starting single-process OCR run...")
    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

    if len(sys.argv) == 1:
        # No command-line arguments: run with the default configuration
        print("ℹ️ No command line arguments provided. Running with default configuration...")

        # Default configuration
        default_config = {
            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
            "output_dir": "./OmniDocBench_Results_Single",
            "pipeline": "PP-StructureV3",
            "device": "gpu:0",
            "batch_size": 4,
        }

        # Rebuild sys.argv from the default configuration
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])

        # Enable test mode
        sys.argv.append("--test_mode")

    sys.exit(main())
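
# Example command-line usage (illustrative; adjust the paths to your environment):
#
#   python ppstructurev3_single_process.py \
#       --input_dir ../../OmniDocBench/OpenDataLab___OmniDocBench/images \
#       --output_dir ./OmniDocBench_Results_Single \
#       --device gpu:0 --batch_size 4 --test_mode
#
# Or process an explicit list of image paths, one per line:
#
#   python ppstructurev3_single_process.py --input_file_list ./images.txt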