ppstructurev3_single_process.py

  1. """单进程运行稳定"""
  2. import json
  3. import time
  4. import os
  5. import traceback
  6. import argparse
  7. import sys
  8. import warnings
  9. from pathlib import Path
  10. from typing import List, Dict, Any
  11. import cv2
  12. import numpy as np
  13. # 抑制特定警告
  14. warnings.filterwarnings("ignore", message="To copy construct from a tensor")
  15. warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
  16. warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
  17. from paddlex import create_pipeline
  18. from paddlex.utils.device import constr_device, parse_device
  19. from tqdm import tqdm
  20. from dotenv import load_dotenv
  21. load_dotenv(override=True)
  22. from utils import (
  23. get_image_files_from_dir,
  24. get_image_files_from_list,
  25. get_image_files_from_csv,
  26. collect_pid_files
  27. )
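
# The helpers above come from a local utils module. Signatures inferred from
# their use in this script (a sketch, not verified against the actual utils.py):
#   get_image_files_from_dir(dir_path) -> List[str]
#   get_image_files_from_list(list_file) -> List[str]
#   get_image_files_from_csv(csv_file, status_filter) -> List[str]
#   collect_pid_files(results_json_path) -> iterable of (image_path, status) pairs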

def process_images_single_process(image_paths: List[str],
                                  pipeline_name: str = "PP-StructureV3",
                                  device: str = "gpu:0",
                                  batch_size: int = 1,
                                  output_dir: str = "./output") -> List[Dict[str, Any]]:
    """
    Single-process image processing.

    Args:
        image_paths: List of image paths.
        pipeline_name: Pipeline name.
        device: Device string, e.g. "gpu:0" or "cpu".
        batch_size: Batch size.
        output_dir: Output directory.

    Returns:
        List of per-image result dicts.
    """
    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
    try:
        # Quiet user warnings raised by the pipeline's dependencies
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        # Initialize the pipeline (may download model weights on first run)
        pipeline = create_pipeline(pipeline_name, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    all_results = []
    total_images = len(image_paths)
    print(f"Processing {total_images} images with batch size {batch_size}")

    # tqdm progress bar with extra per-batch statistics
    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        # Process the images batch by batch
        for i in range(0, total_images, batch_size):
            batch = image_paths[i:i + batch_size]
            batch_start_time = time.time()
            try:
                # Run prediction; the flags toggle the pipeline's sub-modules
                results = pipeline.predict(
                    batch,
                    use_doc_orientation_classify=True,  # classify page orientation
                    use_doc_unwarping=False,            # skip document unwarping
                    use_seal_recognition=True,
                    use_chart_recognition=True,
                    use_table_recognition=True,
                    use_formula_recognition=True,
                )
                batch_processing_time = time.time() - batch_start_time
                batch_results = []
                # Save each result and record its outcome
                for result in results:
                    # Resolve the input path first so the error handler below can use it
                    input_path = Path(result["input_path"])
                    try:
                        # Pages of multi-page inputs get a page-index suffix
                        if result.get("page_index") is not None:
                            output_filename = f"{input_path.stem}_{result['page_index']}"
                        else:
                            output_filename = f"{input_path.stem}"
                        # Save the JSON and Markdown outputs
                        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
                        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
                        result.save_to_json(json_output_path)
                        result.save_to_markdown(md_output_path)
                        # Record the successful result
                        batch_results.append({
                            "image_path": str(input_path),
                            "processing_time": batch_processing_time / len(batch),  # average per image
                            "success": True,
                            "device": device,
                            "output_json": json_output_path,
                            "output_md": md_output_path
                        })
                    except Exception as e:
                        print(f"Error saving result for {input_path}: {e}", file=sys.stderr)
                        traceback.print_exc()
                        batch_results.append({
                            "image_path": str(input_path),
                            "processing_time": 0,
                            "success": False,
                            "device": device,
                            "error": str(e)
                        })

                all_results.extend(batch_results)
                # Update the progress bar
                total_success = sum(1 for r in all_results if r.get('success', False))
                avg_time = batch_processing_time / len(batch)
                pbar.update(len(batch))
                pbar.set_postfix({
                    'batch_time': f"{batch_processing_time:.2f}s",
                    'avg_time': f"{avg_time:.2f}s/img",
                    'success': f"{total_success}/{len(all_results)}",
                    'rate': f"{total_success / len(all_results) * 100:.1f}%"
                })
            except Exception as e:
                print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
                traceback.print_exc()
                # Record an error result for every image in the failed batch
                error_results = []
                for img_path in batch:
                    error_results.append({
                        "image_path": str(img_path),
                        "processing_time": 0,
                        "success": False,
                        "device": device,
                        "error": str(e)
                    })
                all_results.extend(error_results)
                pbar.update(len(batch))

    return all_results
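
# Minimal direct-call sketch (hypothetical paths; the function is normally
# driven by the CLI in main() below):
#
#   results = process_images_single_process(
#       ["./images/page_001.png", "./images/page_002.png"],
#       pipeline_name="PP-StructureV3",
#       device="gpu:0",
#       batch_size=2,
#       output_dir="./output",
#   )
#   ok = sum(1 for r in results if r["success"])
#   print(f"{ok}/{len(results)} images processed successfully")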

def main():
    """Entry point: parse arguments, run processing, and write summaries."""
    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Single Process Processing")

    # Argument definitions
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_dir", type=str, help="Input directory")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
    parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
    parser.add_argument("--collect_results", type=str, help="Collect processing results into the given CSV file")
    args = parser.parse_args()
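
    # Example invocations (hypothetical paths, shown as a sketch):
    #   python ppstructurev3_single_process.py --input_dir ./images --output_dir ./out \
    #       --device gpu:0 --batch_size 2
    #   python ppstructurev3_single_process.py --input_csv ./processed_files.csv \
    #       --output_dir ./out --collect_results ./retry.csv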

    try:
        # Build the list of image files
        if args.input_csv:
            # Read from a CSV, keeping only rows whose status is "fail"
            image_files = get_image_files_from_csv(args.input_csv, "fail")
            print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
        elif args.input_file_list:
            # Read from a plain file list
            image_files = get_image_files_from_list(args.input_file_list)
        else:
            # Read from a directory
            input_dir = Path(args.input_dir).resolve()
            print(f"📁 Input dir: {input_dir}")
            if not input_dir.exists():
                print(f"❌ Input directory does not exist: {input_dir}")
                return 1
            image_files = get_image_files_from_dir(input_dir)

        output_dir = Path(args.output_dir).resolve()
        print(f"Output dir: {output_dir}")
        print(f"Found {len(image_files)} image files")

        if args.test_mode:
            image_files = image_files[:20]
            print(f"Test mode: processing only {len(image_files)} images")

        print(f"Using device: {args.device}")
        print(f"Batch size: {args.batch_size}")

        # Run the processing
        start_time = time.time()
        results = process_images_single_process(
            image_files,
            args.pipeline,
            args.device,
            args.batch_size,
            str(output_dir)
        )
        total_time = time.time() - start_time

        # Summarize the results
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"   Total files: {len(image_files)}")
        print(f"   Successful: {success_count}")
        print(f"   Failed: {error_count}")
        if len(image_files) > 0:
            print(f"   Success rate: {success_count / len(image_files) * 100:.2f}%")
        print("⏱️ Performance:")
        print(f"   Total time: {total_time:.2f} seconds")
        if total_time > 0 and len(image_files) > 0:
            print(f"   Throughput: {len(image_files) / total_time:.2f} images/second")
            print(f"   Avg time per image: {total_time / len(image_files):.2f} seconds")

        # Summary statistics
        stats = {
            "total_files": len(image_files),
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(image_files) / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
            "batch_size": args.batch_size,
            "device": args.device,
            "pipeline": args.pipeline,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results (stats plus per-image records)
        output_file_name = Path(output_dir).name
        output_file = os.path.join(output_dir, f"{output_file_name}.json")
        final_results = {
            "stats": stats,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        print(f"💾 Results saved to: {output_file}")
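
        # Shape of the saved JSON (a sketch derived from the dicts built above):
        # {
        #   "stats":   {"total_files": ..., "success_count": ..., "throughput": ..., ...},
        #   "results": [{"image_path": ..., "success": ..., "output_json": ..., ...}, ...]
        # }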

        if args.collect_results:
            # Write an image_path,status CSV so failed files can be re-run later
            processed_files = collect_pid_files(output_file)
            output_file_processed = Path(args.collect_results).resolve()
            with open(output_file_processed, 'w', encoding='utf-8') as f:
                f.write("image_path,status\n")
                for file_path, status in processed_files:
                    f.write(f"{file_path},{status}\n")
            print(f"💾 Processed files saved to: {output_file_processed}")

        return 0
    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    print("🚀 Starting single-process OCR run...")
    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
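
    # To pin the run to a single GPU, set the standard CUDA environment variable
    # before launching (a sketch; nothing here is specific to this script):
    #   CUDA_VISIBLE_DEVICES=0 python ppstructurev3_single_process.py --input_dir ./images --output_dir ./out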

    if len(sys.argv) == 1:
        # No command line arguments: run with the default configuration
        print("ℹ️ No command line arguments provided. Running with default configuration...")
        # Default configuration
        default_config = {
            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
            "output_dir": "./OmniDocBench_PPStructureV3_Results",
            "pipeline": "PP-StructureV3",
            "device": "gpu:0",
            "batch_size": 2,
            "collect_results": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
        }
        # Alternative: re-run only the files marked "fail" in a previous pass
        # default_config = {
        #     "input_csv": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
        #     "output_dir": "./OmniDocBench_PPStructureV3_Results",
        #     "pipeline": "PP-StructureV3",
        #     "device": "gpu:0",
        #     "batch_size": 2,
        #     "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
        # }
        # Build argv from the default configuration
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])
        # Uncomment to run in test mode (first 20 images only)
        # sys.argv.append("--test_mode")

    sys.exit(main())