# ppstructurev3_single_process.py
"""Single-process runner (stable): PP-StructureV3 OCR over a batch of images."""
import json
import time
import os
import traceback
import argparse
import sys
import warnings
from pathlib import Path
from typing import List, Dict, Any
import cv2
import numpy as np
# Suppress specific noisy warnings from the underlying frameworks
warnings.filterwarnings("ignore", message="To copy construct from a tensor")
warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
from paddlex import create_pipeline
from paddlex.utils.device import constr_device, parse_device
from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv(override=True)
from utils import (
    get_image_files_from_dir,
    get_image_files_from_list,
    get_image_files_from_csv,
    collect_pid_files
)
  28. def process_images_single_process(image_paths: List[str],
  29. pipeline_name: str = "PP-StructureV3",
  30. device: str = "gpu:0",
  31. output_dir: str = "./output") -> List[Dict[str, Any]]:
  32. """
  33. 单进程版本的图像处理函数
  34. Args:
  35. image_paths: 图像路径列表
  36. pipeline_name: Pipeline名称
  37. device: 设备字符串,如"gpu:0"或"cpu"
  38. output_dir: 输出目录
  39. Returns:
  40. 处理结果列表
  41. """
  42. # 创建输出目录
  43. output_path = Path(output_dir)
  44. output_path.mkdir(parents=True, exist_ok=True)
  45. print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
  46. try:
  47. # 设置环境变量以减少警告
  48. os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
  49. # 初始化pipeline
  50. pipeline = create_pipeline(pipeline_name, device=device)
  51. print(f"Pipeline initialized successfully on {device}")
  52. except Exception as e:
  53. print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
  54. traceback.print_exc()
  55. return []
  56. all_results = []
  57. total_images = len(image_paths)
  58. print(f"Processing {total_images} images one by one")
  59. # 使用tqdm显示进度
  60. with tqdm(total=total_images, desc="Processing images", unit="img",
  61. bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
  62. # 逐个处理图像
  63. for img_path in image_paths:
  64. start_time = time.time()
  65. try:
  66. # 使用pipeline预测单个图像
  67. results = pipeline.predict(
  68. img_path, # 传入单个文件路径
  69. use_doc_orientation_classify=True,
  70. use_doc_unwarping=False,
  71. use_seal_recognition=True,
  72. use_table_recognition=True,
  73. use_formula_recognition=False, # 暂时关闭公式识别以避免错误
  74. use_chart_recognition=True,
  75. )
  76. processing_time = time.time() - start_time
  77. # 处理结果
  78. for result in results:
  79. try:
  80. input_path = Path(result["input_path"])
  81. # 生成输出文件名
  82. if result.get("page_index") is not None:
  83. output_filename = f"{input_path.stem}_{result['page_index']}"
  84. else:
  85. output_filename = f"{input_path.stem}"
  86. # 保存JSON和Markdown文件
  87. json_output_path = str(Path(output_dir, f"{output_filename}.json"))
  88. md_output_path = str(Path(output_dir, f"{output_filename}.md"))
  89. result.save_to_json(json_output_path)
  90. result.save_to_markdown(md_output_path)
  91. # 记录处理结果
  92. all_results.append({
  93. "image_path": str(input_path),
  94. "processing_time": processing_time,
  95. "success": True,
  96. "device": device,
  97. "output_json": json_output_path,
  98. "output_md": md_output_path
  99. })
  100. except Exception as e:
  101. print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
  102. traceback.print_exc()
  103. all_results.append({
  104. "image_path": str(img_path),
  105. "processing_time": 0,
  106. "success": False,
  107. "device": device,
  108. "error": str(e)
  109. })
  110. # 更新进度条
  111. success_count = sum(1 for r in all_results if r.get('success', False))
  112. pbar.update(1)
  113. pbar.set_postfix({
  114. 'time': f"{processing_time:.2f}s",
  115. 'success': f"{success_count}/{len(all_results)}",
  116. 'rate': f"{success_count/len(all_results)*100:.1f}%"
  117. })
  118. except Exception as e:
  119. print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
  120. traceback.print_exc()
  121. # 添加错误结果
  122. all_results.append({
  123. "image_path": str(img_path),
  124. "processing_time": 0,
  125. "success": False,
  126. "device": device,
  127. "error": str(e)
  128. })
  129. pbar.update(1)
  130. return all_results
  131. def main():
  132. """主函数"""
  133. parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Single Process Processing")
  134. # 参数定义
  135. input_group = parser.add_mutually_exclusive_group(required=True)
  136. input_group.add_argument("--input_dir", type=str, help="Input directory")
  137. input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
  138. input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
  139. parser.add_argument("--output_dir", type=str, help="Output directory")
  140. parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
  141. parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
  142. parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
  143. parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 images)")
  144. parser.add_argument("--collect_results",type=str, help="收集处理结果到指定CSV文件")
  145. args = parser.parse_args()
  146. try:
  147. # 获取图像文件列表
  148. if args.input_csv:
  149. # 从CSV文件读取
  150. image_files = get_image_files_from_csv(args.input_csv, "fail")
  151. print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
  152. elif args.input_file_list:
  153. # 从文件列表读取
  154. image_files = get_image_files_from_list(args.input_file_list)
  155. else:
  156. # 从目录读取
  157. input_dir = Path(args.input_dir).resolve()
  158. print(f"📁 Input dir: {input_dir}")
  159. if not input_dir.exists():
  160. print(f"❌ Input directory does not exist: {input_dir}")
  161. return 1
  162. print(f"Input dir: {input_dir}")
  163. image_files = get_image_files_from_dir(input_dir)
  164. output_dir = Path(args.output_dir).resolve()
  165. print(f"Output dir: {output_dir}")
  166. print(f"Found {len(image_files)} image files")
  167. if args.test_mode:
  168. image_files = image_files[:20]
  169. print(f"Test mode: processing only {len(image_files)} images")
  170. print(f"Using device: {args.device}")
  171. # 开始处理(删除了 batch_size 参数)
  172. start_time = time.time()
  173. results = process_images_single_process(
  174. image_files,
  175. args.pipeline,
  176. args.device,
  177. str(output_dir)
  178. )
  179. total_time = time.time() - start_time
  180. # 统计结果
  181. success_count = sum(1 for r in results if r.get('success', False))
  182. error_count = len(results) - success_count
  183. print(f"\n" + "="*60)
  184. print(f"✅ Processing completed!")
  185. print(f"📊 Statistics:")
  186. print(f" Total files: {len(image_files)}")
  187. print(f" Successful: {success_count}")
  188. print(f" Failed: {error_count}")
  189. if len(image_files) > 0:
  190. print(f" Success rate: {success_count / len(image_files) * 100:.2f}%")
  191. print(f"⏱️ Performance:")
  192. print(f" Total time: {total_time:.2f} seconds")
  193. if total_time > 0:
  194. print(f" Throughput: {len(image_files) / total_time:.2f} images/second")
  195. print(f" Avg time per image: {total_time / len(image_files):.2f} seconds")
  196. # 保存结果统计(删除了 batch_size 统计)
  197. stats = {
  198. "total_files": len(image_files),
  199. "success_count": success_count,
  200. "error_count": error_count,
  201. "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
  202. "total_time": total_time,
  203. "throughput": len(image_files) / total_time if total_time > 0 else 0,
  204. "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
  205. "device": args.device,
  206. "pipeline": args.pipeline,
  207. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
  208. }
  209. # 保存最终结果
  210. output_file_name = Path(output_dir).name
  211. output_file = os.path.join(output_dir, f"{output_file_name}.json")
  212. final_results = {
  213. "stats": stats,
  214. "results": results
  215. }
  216. with open(output_file, 'w', encoding='utf-8') as f:
  217. json.dump(final_results, f, ensure_ascii=False, indent=2)
  218. print(f"💾 Results saved to: {output_file}")
  219. if args.collect_results:
  220. processed_files = collect_pid_files(output_file)
  221. output_file_processed = Path(args.collect_results).resolve()
  222. with open(output_file_processed, 'w', encoding='utf-8') as f:
  223. f.write("image_path,status\n")
  224. for file_path, status in processed_files:
  225. f.write(f"{file_path},{status}\n")
  226. print(f"💾 Processed files saved to: {output_file_processed}")
  227. return 0
  228. except Exception as e:
  229. print(f"❌ Processing failed: {e}", file=sys.stderr)
  230. traceback.print_exc()
  231. return 1
if __name__ == "__main__":
    # Script entry: report GPU visibility, then dispatch to main().
    print(f"🚀 启动单进程OCR程序...")
    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
    if len(sys.argv) == 1:
        # No command line arguments: fall back to the default configuration
        print("ℹ️ No command line arguments provided. Running with default configuration...")
        # Default configuration (batch_size was removed)
        default_config = {
            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
            "output_dir": "./OmniDocBench_PPStructureV3_Results",
            "pipeline": "./my_config/PP-StructureV3.yaml",
            "device": "gpu:0",
            "collect_results": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
        }
        # default_config = {
        #     "input_csv": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
        #     "output_dir": "./OmniDocBench_PPStructureV3_Results",
        #     "pipeline": "./my_config/PP-StructureV3.yaml",
        #     "device": "gpu:0",
        #     "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
        # }
        # Synthesize argv from the default configuration
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])
        # Test mode
        # sys.argv.append("--test_mode")
    sys.exit(main())