# ppstructurev3_single_process.py
  1. """PDF转图像后统一处理"""
  2. import json
  3. import time
  4. import os
  5. import traceback
  6. import argparse
  7. import sys
  8. import warnings
  9. from pathlib import Path
  10. from typing import List, Dict, Any, Union
  11. import cv2
  12. import numpy as np
  13. # 抑制特定警告
  14. warnings.filterwarnings("ignore", message="To copy construct from a tensor")
  15. warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
  16. warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
  17. from paddlex import create_pipeline
  18. from paddlex.utils.device import constr_device, parse_device
  19. from tqdm import tqdm
  20. from dotenv import load_dotenv
  21. load_dotenv(override=True)
  22. from utils import (
  23. collect_pid_files,
  24. get_input_files,
  25. )
  26. from ppstructurev3_utils import (
  27. convert_pruned_result_to_json,
  28. save_output_images,
  29. save_markdown_content
  30. )
  31. # 🎯 新增:导入适配器
  32. from adapters import apply_table_recognition_adapter, restore_original_function
  33. def process_images_unified(image_paths: List[str],
  34. pipeline_name: str = "PP-StructureV3",
  35. device: str = "gpu:0",
  36. output_dir: str = "./output",
  37. normalize_numbers: bool = True,
  38. use_enhanced_adapter: bool = True,
  39. **kwargs) -> List[Dict[str, Any]]: # 🎯 新增 **kwargs
  40. """
  41. 统一的图像处理函数,支持数字标准化和多种 pipeline
  42. """
  43. # 创建输出目录
  44. output_path = Path(output_dir)
  45. output_path.mkdir(parents=True, exist_ok=True)
  46. # 🎯 应用适配器
  47. adapter_applied = False
  48. if use_enhanced_adapter:
  49. adapter_applied = apply_table_recognition_adapter()
  50. if adapter_applied:
  51. print("🎯 Enhanced table recognition adapter activated")
  52. else:
  53. print("⚠️ Failed to apply adapter, using original implementation")
  54. print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
  55. try:
  56. # 设置环境变量以减少警告
  57. os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
  58. # 初始化pipeline
  59. pipeline = create_pipeline(pipeline_name, device=device)
  60. print(f"Pipeline initialized successfully on {device}")
  61. except Exception as e:
  62. print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
  63. traceback.print_exc()
  64. if adapter_applied:
  65. restore_original_function()
  66. return []
  67. try:
  68. all_results = []
  69. total_images = len(image_paths)
  70. print(f"Processing {total_images} images one by one")
  71. print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
  72. print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
  73. # 🎯 检测 pipeline 类型
  74. is_paddleocr_vl = 'PaddleOCR-VL'.lower() in str(pipeline_name).lower()
  75. # 使用tqdm显示进度
  76. with tqdm(total=total_images, desc="Processing images", unit="img",
  77. bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
  78. # 逐个处理图像
  79. for img_path in image_paths:
  80. start_time = time.time()
  81. try:
  82. # 🎯 根据 pipeline 类型使用不同的参数
  83. if is_paddleocr_vl:
  84. # PaddleOCR-VL 使用驼峰命名
  85. predict_kwargs = {
  86. 'input': img_path,
  87. 'useLayoutDetection': kwargs.get('use_layout_detection', False),
  88. 'useDocOrientationClassify': kwargs.get('use_doc_orientation', False),
  89. 'useDocUnwarping': kwargs.get('use_doc_unwarping', False),
  90. }
  91. else:
  92. # PP-StructureV3 使用下划线命名
  93. predict_kwargs = {
  94. 'img_path': img_path,
  95. 'use_doc_orientation_classify': kwargs.get('use_doc_orientation', False), # 流水分析场景关闭方向分类
  96. 'use_doc_unwarping': kwargs.get('use_doc_unwarping', False),
  97. 'use_layout_detection': kwargs.get('use_layout_detection', True),
  98. 'use_seal_recognition': kwargs.get('use_seal_recognition', True),
  99. 'use_table_recognition': kwargs.get('use_table_recognition', True),
  100. 'use_formula_recognition': kwargs.get('use_formula_recognition', False),
  101. 'use_chart_recognition': kwargs.get('use_chart_recognition', True),
  102. 'use_ocr_results_with_table_cells': kwargs.get('use_ocr_results_with_table_cells', True),
  103. 'use_table_orientation_classify': kwargs.get('use_table_orientation_classify', False),
  104. 'use_wired_table_cells_trans_to_html': kwargs.get('use_wired_table_cells_trans_to_html', True),
  105. 'use_wireless_table_cells_trans_to_html': kwargs.get('use_wireless_table_cells_trans_to_html', True),
  106. }
  107. # 使用pipeline预测
  108. results = pipeline.predict(**predict_kwargs)
  109. processing_time = time.time() - start_time
  110. # 处理结果
  111. for idx, result in enumerate(results):
  112. if idx > 0:
  113. raise ValueError("Multiple results found for a single image")
  114. try:
  115. input_path = Path(result["input_path"])
  116. # 生成输出文件名
  117. if result.get("page_index") is not None:
  118. output_filename = f"{input_path.stem}_{result['page_index']}"
  119. else:
  120. output_filename = f"{input_path.stem}"
  121. # 转换并保存标准JSON格式
  122. json_content = result.json['res']
  123. json_output_path, converted_json = convert_pruned_result_to_json(
  124. json_content,
  125. str(input_path),
  126. output_dir,
  127. output_filename,
  128. normalize_numbers=normalize_numbers
  129. )
  130. # 保存输出图像
  131. img_content = result.img
  132. saved_images = save_output_images(img_content, str(output_dir), output_filename)
  133. # 保存Markdown内容
  134. markdown_content = result.markdown
  135. md_output_path = save_markdown_content(
  136. markdown_content,
  137. output_dir,
  138. output_filename,
  139. normalize_numbers=normalize_numbers,
  140. key_text='markdown_texts',
  141. key_images='markdown_images',
  142. json_data=converted_json # 🎯 新增参数
  143. )
  144. # 记录处理结果
  145. all_results.append({
  146. "image_path": str(input_path),
  147. "processing_time": processing_time,
  148. "success": True,
  149. "device": device,
  150. "output_json": json_output_path,
  151. "output_md": md_output_path,
  152. "is_pdf_page": "_page_" in input_path.name,
  153. "processing_info": converted_json.get('processing_info', {})
  154. })
  155. except Exception as e:
  156. print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
  157. traceback.print_exc()
  158. all_results.append({
  159. "image_path": str(img_path),
  160. "processing_time": 0,
  161. "success": False,
  162. "device": device,
  163. "error": str(e)
  164. })
  165. # 更新进度条
  166. success_count = sum(1 for r in all_results if r.get('success', False))
  167. pbar.update(1)
  168. pbar.set_postfix({
  169. 'time': f"{processing_time:.2f}s",
  170. 'success': f"{success_count}/{len(all_results)}",
  171. 'rate': f"{success_count/len(all_results)*100:.1f}%"
  172. })
  173. except Exception as e:
  174. print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
  175. traceback.print_exc()
  176. # 添加错误结果
  177. all_results.append({
  178. "image_path": str(img_path),
  179. "processing_time": 0,
  180. "success": False,
  181. "device": device,
  182. "error": str(e)
  183. })
  184. pbar.update(1)
  185. return all_results
  186. finally:
  187. # 🎯 清理:恢复原始函数
  188. if adapter_applied:
  189. restore_original_function()
  190. print("🔄 Original function restored")
  191. def main():
  192. """主函数"""
  193. parser = argparse.ArgumentParser(description="PaddleX Unified PDF/Image Processor")
  194. # 参数定义
  195. input_group = parser.add_mutually_exclusive_group(required=True)
  196. input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image file)")
  197. input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
  198. input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
  199. input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
  200. parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
  201. parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
  202. parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
  203. parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF to image conversion")
  204. parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化")
  205. parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
  206. parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
  207. parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器") # 🎯 新增参数
  208. args = parser.parse_args()
  209. normalize_numbers = not args.no_normalize
  210. use_enhanced_adapter = not args.no_adapter
  211. # 🎯 构建 predict 参数
  212. predict_kwargs = {}
  213. try:
  214. # 获取并预处理输入文件
  215. print("🔄 Preprocessing input files...")
  216. input_files = get_input_files(args)
  217. if not input_files:
  218. print("❌ No input files found or processed")
  219. return 1
  220. if args.test_mode:
  221. input_files = input_files[:20]
  222. print(f"Test mode: processing only {len(input_files)} images")
  223. print(f"Using device: {args.device}")
  224. # 开始处理
  225. start_time = time.time()
  226. results = process_images_unified(
  227. input_files,
  228. args.pipeline,
  229. args.device,
  230. args.output_dir,
  231. normalize_numbers=normalize_numbers,
  232. use_enhanced_adapter=use_enhanced_adapter,
  233. **predict_kwargs # 🎯 传递所有参数
  234. )
  235. total_time = time.time() - start_time
  236. # 统计结果
  237. success_count = sum(1 for r in results if r.get('success', False))
  238. error_count = len(results) - success_count
  239. pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
  240. total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
  241. print(f"\n" + "="*60)
  242. print(f"✅ Processing completed!")
  243. print(f"📊 Statistics:")
  244. print(f" Total files processed: {len(input_files)}")
  245. print(f" PDF pages processed: {pdf_page_count}")
  246. print(f" Regular images processed: {len(input_files) - pdf_page_count}")
  247. print(f" Successful: {success_count}")
  248. print(f" Failed: {error_count}")
  249. if len(input_files) > 0:
  250. print(f" Success rate: {success_count / len(input_files) * 100:.2f}%")
  251. if normalize_numbers:
  252. print(f" 总标准化字符数: {total_changes}")
  253. print(f"⏱️ Performance:")
  254. print(f" Total time: {total_time:.2f} seconds")
  255. if total_time > 0:
  256. print(f" Throughput: {len(input_files) / total_time:.2f} files/second")
  257. print(f" Avg time per file: {total_time / len(input_files):.2f} seconds")
  258. # 保存结果统计
  259. stats = {
  260. "total_files": len(input_files),
  261. "pdf_pages": pdf_page_count,
  262. "regular_images": len(input_files) - pdf_page_count,
  263. "success_count": success_count,
  264. "error_count": error_count,
  265. "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
  266. "total_time": total_time,
  267. "throughput": len(input_files) / total_time if total_time > 0 else 0,
  268. "avg_time_per_file": total_time / len(input_files) if len(input_files) > 0 else 0,
  269. "device": args.device,
  270. "pipeline": args.pipeline,
  271. "pdf_dpi": args.pdf_dpi,
  272. "normalize_numbers": normalize_numbers,
  273. "total_character_changes": total_changes,
  274. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
  275. }
  276. # 保存最终结果
  277. output_file_name = Path(args.output_dir).name
  278. output_file = os.path.join(args.output_dir, f"{output_file_name}_unified.json")
  279. final_results = {
  280. "stats": stats,
  281. "results": results
  282. }
  283. with open(output_file, 'w', encoding='utf-8') as f:
  284. json.dump(final_results, f, ensure_ascii=False, indent=2)
  285. print(f"💾 Results saved to: {output_file}")
  286. # 如果没有收集结果的路径,使用缺省文件名,和output_dir同一路径
  287. if not args.collect_results:
  288. output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
  289. else:
  290. output_file_processed = Path(args.collect_results).resolve()
  291. processed_files = collect_pid_files(output_file)
  292. with open(output_file_processed, 'w', encoding='utf-8') as f:
  293. f.write("image_path,status\n")
  294. for file_path, status in processed_files:
  295. f.write(f"{file_path},{status}\n")
  296. print(f"💾 Processed files saved to: {output_file_processed}")
  297. return 0
  298. except Exception as e:
  299. print(f"❌ Processing failed: {e}", file=sys.stderr)
  300. traceback.print_exc()
  301. return 1
  302. if __name__ == "__main__":
  303. print(f"🚀 启动统一PDF/图像处理程序...")
  304. print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
  305. if len(sys.argv) == 1:
  306. # 如果没有命令行参数,使用默认配置运行
  307. print("ℹ️ No command line arguments provided. Running with default configuration...")
  308. # 默认配置
  309. default_config = {
  310. # "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/2023年度报告母公司.pdf",
  311. "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/2023年度报告母公司/2023年度报告母公司_page_027.png",
  312. "output_dir": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results",
  313. "collect_results": f"/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
  314. # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
  315. # "output_dir": "./OmniDocBench_PPStructureV3_Results",
  316. # "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
  317. "pipeline": "./my_config/PP-StructureV3.yaml",
  318. "device": "gpu",
  319. }
  320. # 构造参数
  321. sys.argv = [sys.argv[0]]
  322. for key, value in default_config.items():
  323. sys.argv.extend([f"--{key}", str(value)])
  324. # 可以添加禁用标准化选项
  325. # sys.argv.append("--no-normalize")
  326. # 测试模式
  327. # sys.argv.append("--test_mode")
  328. sys.exit(main())