ppstructurev3_single_process.py

  1. """PDF转图像后统一处理"""
  2. import json
  3. import time
  4. import os
  5. import traceback
  6. import argparse
  7. import sys
  8. import warnings
  9. from pathlib import Path
  10. from typing import List, Dict, Any, Union
  11. import cv2
  12. import numpy as np
  13. # 抑制特定警告
  14. warnings.filterwarnings("ignore", message="To copy construct from a tensor")
  15. warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
  16. warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
  17. from paddlex import create_pipeline
  18. from paddlex.utils.device import constr_device, parse_device
  19. from tqdm import tqdm
  20. from dotenv import load_dotenv
  21. load_dotenv(override=True)
  22. from utils import (
  23. get_image_files_from_dir,
  24. get_image_files_from_list,
  25. get_image_files_from_csv,
  26. collect_pid_files,
  27. load_images_from_pdf,
  28. normalize_financial_numbers,
  29. normalize_markdown_table
  30. )
def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
    """
    Convert a PDF into per-page image files.

    Args:
        pdf_file: Path to the PDF file.
        output_dir: Output directory; defaults to a directory named after the PDF.
        dpi: Rendering resolution.

    Returns:
        List of paths to the generated image files.
    """
    pdf_path = Path(pdf_file)
    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
        print(f"❌ Invalid PDF file: {pdf_path}")
        return []

    # If no output directory is given, use a directory named after the PDF
    if output_dir is None:
        output_path = pdf_path.parent / pdf_path.stem
    else:
        output_path = Path(output_dir) / pdf_path.stem
    output_path = output_path.resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Load the PDF pages as images via the helper from utils
        images = load_images_from_pdf(str(pdf_path), dpi=dpi)
        image_paths = []
        for i, image in enumerate(images):
            # Build the per-page image filename
            image_filename = f"{pdf_path.stem}_page_{i + 1:03d}.png"
            image_path = output_path / image_filename
            # Save the page image
            image.save(str(image_path))
            image_paths.append(str(image_path))
        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
        return image_paths
    except Exception as e:
        print(f"❌ Error converting PDF {pdf_path}: {e}")
        traceback.print_exc()
        return []
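
# Usage sketch for convert_pdf_to_images (the path below is illustrative,
# not part of this repo):
#     pages = convert_pdf_to_images("docs/report.pdf", output_dir="./pages", dpi=200)
#     # -> ["/abs/path/pages/report/report_page_001.png", ...]
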
def get_input_files(args) -> List[str]:
    """
    Collect the input file list, handling PDFs and images uniformly.

    Args:
        args: Parsed command-line arguments.

    Returns:
        List of image file paths ready for processing.
    """
    input_files = []

    # Gather the raw input files
    if args.input_csv:
        raw_files = get_image_files_from_csv(args.input_csv, "fail")
    elif args.input_file_list:
        raw_files = get_image_files_from_list(args.input_file_list)
    elif args.input_file:
        raw_files = [str(Path(args.input_file).resolve())]
    else:
        input_dir = Path(args.input_dir).resolve()
        if not input_dir.exists():
            print(f"❌ Input directory does not exist: {input_dir}")
            return []
        # Collect all supported files (images and PDFs)
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
        pdf_extensions = ['.pdf']
        raw_files = []
        for ext in image_extensions + pdf_extensions:
            raw_files.extend(input_dir.glob(f"*{ext}"))
            raw_files.extend(input_dir.glob(f"*{ext.upper()}"))
        # De-duplicate (both globs match on case-insensitive filesystems)
        raw_files = sorted({str(f) for f in raw_files})

    # Handle PDFs and images separately
    pdf_count = 0
    image_count = 0
    for file_path in raw_files:
        file_path = Path(file_path)
        if file_path.suffix.lower() == '.pdf':
            # Convert the PDF into per-page images
            print(f"📄 Processing PDF: {file_path.name}")
            pdf_images = convert_pdf_to_images(
                str(file_path),
                args.output_dir,
                dpi=args.pdf_dpi
            )
            input_files.extend(pdf_images)
            pdf_count += 1
        else:
            # Add the image file directly
            if file_path.exists():
                input_files.append(str(file_path))
                image_count += 1

    print("📊 Input summary:")
    print(f"   PDF files processed: {pdf_count}")
    print(f"   Image files found: {image_count}")
    print(f"   Total image files to process: {len(input_files)}")
    return input_files
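
# Standalone usage sketch for get_input_files, assuming the same argument
# names that main() defines below (values here are hypothetical):
#     ns = argparse.Namespace(input_csv=None, input_file_list=None, input_file=None,
#                             input_dir="./scans", output_dir="./out", pdf_dpi=200)
#     files = get_input_files(ns)
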
def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
    """
    Apply number normalization to a pipeline result.

    Args:
        result: Result object returned by the pipeline.
        normalize_numbers: Whether number normalization is enabled.

    Returns:
        Dictionary describing the normalization that was applied.
    """
    if not normalize_numbers:
        return {
            "normalize_numbers": False,
            "changes_applied": False,
            "character_changes_count": 0,
            "parsing_res_tables_count": 0,
            "table_res_list_count": 0,
            "table_consistency_fixed": False
        }

    changes_count = 0
    original_data = {}

    # Back up the original data (kept in case debugging is needed)
    if 'parsing_res_list' in result:
        original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
    if 'table_res_list' in result:
        original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]

    try:
        # 1. Normalize the text content in parsing_res_list
        if 'parsing_res_list' in result:
            for item in result['parsing_res_list']:
                if 'block_content' in item and item['block_content']:
                    original_content = str(item['block_content'])
                    # Choose the normalization method based on block_label;
                    # non-table blocks get financial-number normalization
                    # (assumed intent for the imported normalize_financial_numbers helper)
                    if 'block_label' in item and item['block_label'] == 'table':
                        normalized_content = normalize_markdown_table(original_content)
                    else:
                        normalized_content = normalize_financial_numbers(original_content)
                    if original_content != normalized_content:
                        item['block_content'] = normalized_content
                        changes_count += sum(1 for o, n in zip(original_content, normalized_content) if o != n)

        # 2. Normalize the HTML tables in table_res_list
        if 'table_res_list' in result:
            for table_item in result['table_res_list']:
                if 'pred_html' in table_item and table_item['pred_html']:
                    original_html = str(table_item['pred_html'])
                    normalized_html = normalize_markdown_table(original_html)
                    if original_html != normalized_html:
                        table_item['pred_html'] = normalized_html
                        changes_count += sum(1 for o, n in zip(original_html, normalized_html) if o != n)

        # Count the tables seen by each stage
        parsing_res_tables_count = 0
        table_res_list_count = 0
        if 'parsing_res_list' in result:
            parsing_res_tables_count = len([item for item in result['parsing_res_list']
                                            if 'block_label' in item and item['block_label'] == 'table'])
        if 'table_res_list' in result:
            table_res_list_count = len(result['table_res_list'])

        # Check table consistency (we only record the mismatch here; an
        # actual repair would need more concrete rules)
        table_consistency_fixed = False
        if parsing_res_tables_count != table_res_list_count:
            warnings.warn(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
                          f"but table_res_list has {table_res_list_count} tables.")
            table_consistency_fixed = True
            # Repair logic (adding or dropping table items) could go here,
            # but without concrete rules we only count and warn for now.

        return {
            "normalize_numbers": normalize_numbers,
            "changes_applied": changes_count > 0,
            "character_changes_count": changes_count,
            "parsing_res_tables_count": parsing_res_tables_count,
            "table_res_list_count": table_res_list_count,
            "table_consistency_fixed": table_consistency_fixed
        }
    except Exception as e:
        print(f"⚠️ Warning: Error during normalization: {e}")
        return {
            "normalize_numbers": normalize_numbers,
            "changes_applied": False,
            "character_changes_count": 0,
            "normalization_error": str(e)
        }
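
# Usage sketch for normalize_pipeline_result (minimal, hypothetical result
# dict; the actual character changes depend on the normalizers in utils):
#     info = normalize_pipeline_result({"parsing_res_list": [...]})
#     if info["changes_applied"]:
#         print(f"{info['character_changes_count']} characters normalized")
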
def save_normalized_files(result, output_dir: str, filename: str,
                          processing_info: Dict[str, Any], normalize_numbers: bool = True):
    """
    Save the normalized result as JSON and Markdown, embedding the
    processing info in the JSON when normalization changed anything.
    """
    output_path = Path(output_dir)

    # Save the normalized versions
    json_output_path = str(output_path / f"{filename}.json")
    md_output_path = str(output_path / f"{filename}.md")
    result.save_to_json(json_output_path)
    result.save_to_markdown(md_output_path)

    # If normalization changed anything, attach the processing info to the JSON
    if normalize_numbers and processing_info.get('changes_applied', False):
        try:
            # Read the generated JSON file and add the processing info
            with open(json_output_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
            json_data['processing_info'] = processing_info
            # Re-save the JSON including the processing info
            with open(json_output_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"⚠️ Warning: Could not add processing info to JSON: {e}")

    return json_output_path, md_output_path


def process_images_unified(image_paths: List[str],
                           pipeline_name: str = "PP-StructureV3",
                           device: str = "gpu:0",
                           output_dir: str = "./output",
                           normalize_numbers: bool = True) -> List[Dict[str, Any]]:
    """
    Unified image-processing function with optional number normalization.
    """
    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
    try:
        # Reduce warning noise from the pipeline
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        # Initialize the pipeline
        pipeline = create_pipeline(pipeline_name, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    all_results = []
    total_images = len(image_paths)
    print(f"Processing {total_images} images one by one")
    print(f"🔧 Number normalization: {'enabled' if normalize_numbers else 'disabled'}")

    # Show progress with tqdm
    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        # Process the images one at a time
        for img_path in image_paths:
            start_time = time.time()
            try:
                # Run the pipeline on a single image
                results = pipeline.predict(
                    img_path,
                    use_doc_orientation_classify=True,
                    use_doc_unwarping=False,
                    use_seal_recognition=True,
                    use_table_recognition=True,
                    use_formula_recognition=False,
                    use_chart_recognition=True,
                )
                processing_time = time.time() - start_time

                # Handle each result
                for result in results:
                    try:
                        input_path = Path(result["input_path"])
                        # Build the output filename
                        if result.get("page_index") is not None:
                            output_filename = f"{input_path.stem}_{result['page_index']}"
                        else:
                            output_filename = input_path.stem

                        # Apply number normalization
                        processing_info = normalize_pipeline_result(result, normalize_numbers)

                        # Save the JSON and Markdown files (with normalization applied)
                        json_output_path, md_output_path = save_normalized_files(
                            result, output_dir, output_filename, processing_info, normalize_numbers
                        )

                        # Report any table-consistency issue that was flagged
                        if processing_info.get('table_consistency_fixed', False):
                            print(f"🔧 Table count mismatch flagged for: {input_path.name}")

                        # Record the processing result
                        all_results.append({
                            "image_path": str(input_path),
                            "processing_time": processing_time,
                            "success": True,
                            "device": device,
                            "output_json": json_output_path,
                            "output_md": md_output_path,
                            "is_pdf_page": "_page_" in input_path.name,  # marks pages rendered from a PDF
                            "processing_info": processing_info
                        })
                    except Exception as e:
                        print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
                        traceback.print_exc()
                        all_results.append({
                            "image_path": str(img_path),
                            "processing_time": 0,
                            "success": False,
                            "device": device,
                            "error": str(e)
                        })

                # Update the progress bar (guard against an empty result list)
                success_count = sum(1 for r in all_results if r.get('success', False))
                total_done = max(len(all_results), 1)
                pbar.update(1)
                pbar.set_postfix({
                    'time': f"{processing_time:.2f}s",
                    'success': f"{success_count}/{len(all_results)}",
                    'rate': f"{success_count / total_done * 100:.1f}%"
                })
            except Exception as e:
                print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
                traceback.print_exc()
                # Record the failure
                all_results.append({
                    "image_path": str(img_path),
                    "processing_time": 0,
                    "success": False,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)

    return all_results
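
# Usage sketch for process_images_unified (illustrative arguments):
#     results = process_images_unified(["./pages/report_page_001.png"],
#                                      device="cpu", output_dir="./out")
#     ok = [r for r in results if r.get("success")]
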
def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Unified PDF/Image Processor")

    # Argument definitions
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_file", type=str, help="Input file (PDF or image)")
    input_group.add_argument("--input_dir", type=str, help="Input directory (PDF and image files)")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
    parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF-to-image conversion")
    parser.add_argument("--no-normalize", action="store_true", help="Disable number normalization")
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
    parser.add_argument("--collect_results", type=str, help="Collect the processing results into the given CSV file")
    args = parser.parse_args()

    normalize_numbers = not args.no_normalize

    try:
        # Gather and preprocess the input files
        print("🔄 Preprocessing input files...")
        input_files = get_input_files(args)
        if not input_files:
            print("❌ No input files found or processed")
            return 1

        if args.test_mode:
            input_files = input_files[:20]
            print(f"Test mode: processing only {len(input_files)} images")
        print(f"Using device: {args.device}")

        # Start processing
        start_time = time.time()
        results = process_images_unified(
            input_files,
            args.pipeline,
            args.device,
            args.output_dir,
            normalize_numbers=normalize_numbers
        )
        total_time = time.time() - start_time

        # Summarize the results
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count
        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
        total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"   Total files processed: {len(input_files)}")
        print(f"   PDF pages processed: {pdf_page_count}")
        print(f"   Regular images processed: {len(input_files) - pdf_page_count}")
        print(f"   Successful: {success_count}")
        print(f"   Failed: {error_count}")
        if len(input_files) > 0:
            print(f"   Success rate: {success_count / len(input_files) * 100:.2f}%")
        if normalize_numbers:
            print(f"   Total characters normalized: {total_changes}")
        print("⏱️ Performance:")
        print(f"   Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f"   Throughput: {len(input_files) / total_time:.2f} files/second")
            print(f"   Avg time per file: {total_time / len(input_files):.2f} seconds")

        # Collect the run statistics
        stats = {
            "total_files": len(input_files),
            "pdf_pages": pdf_page_count,
            "regular_images": len(input_files) - pdf_page_count,
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file": total_time / len(input_files) if len(input_files) > 0 else 0,
            "device": args.device,
            "pipeline": args.pipeline,
            "pdf_dpi": args.pdf_dpi,
            "normalize_numbers": normalize_numbers,
            "total_character_changes": total_changes,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results
        output_file_name = Path(args.output_dir).name
        output_file = os.path.join(args.output_dir, f"{output_file_name}_unified.json")
        final_results = {
            "stats": stats,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        print(f"💾 Results saved to: {output_file}")

        # Without an explicit collect path, fall back to a timestamped CSV inside output_dir
        if not args.collect_results:
            output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        else:
            output_file_processed = Path(args.collect_results).resolve()
        processed_files = collect_pid_files(output_file)
        with open(output_file_processed, 'w', encoding='utf-8') as f:
            f.write("image_path,status\n")
            for file_path, status in processed_files:
                f.write(f"{file_path},{status}\n")
        print(f"💾 Processed files saved to: {output_file_processed}")
        return 0
    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
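
# Example invocations (paths are illustrative):
#     python ppstructurev3_single_process.py --input_dir ./scans --output_dir ./out
#     python ppstructurev3_single_process.py --input_file report.pdf --output_dir ./out --device cpu --no-normalize
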
if __name__ == "__main__":
    print("🚀 Starting unified PDF/image processing...")
    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

    if len(sys.argv) == 1:
        # No command-line arguments: run with the default configuration
        print("ℹ️ No command line arguments provided. Running with default configuration...")
        # Default configuration
        default_config = {
            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
            "output_dir": "./OmniDocBench_PPStructureV3_Results",
            "pipeline": "./my_config/PP-StructureV3.yaml",
            "device": "gpu:0",
            "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
        }
        # Build the argument vector from the defaults
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])
        # Optionally disable normalization:
        # sys.argv.append("--no-normalize")
        # Optionally enable test mode:
        # sys.argv.append("--test_mode")

    sys.exit(main())