# table_recognition_v2_single_process.py
"""Run only the table_recognition_v2 pipeline and save table HTML, converted for Markdown embedding."""
import os
import re
import sys
import json
import time
import argparse
import traceback
import warnings
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

# Silence known-noisy warnings emitted during paddlex inference.
warnings.filterwarnings("ignore", message="To copy construct from a tensor")
warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")

from paddlex import create_pipeline
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv(override=True)

# Reuse the existing input-collection and saving helpers from this project.
from ppstructurev3_utils import (
    save_output_images,  # saves the visualizations found in result.img
)
from utils import normalize_markdown_table, get_input_files
  24. def html_table_to_markdown(html: str) -> str:
  25. """
  26. 将简单HTML表格转换为Markdown表格。
  27. 支持thead/tbody/tr/td/th;对嵌套复杂标签仅提取纯文本。
  28. """
  29. #去掉<html><body>,以及结尾的</body></html>
  30. html = re.sub(r'</?html>', '', html, flags=re.IGNORECASE)
  31. html = re.sub(r'</?body>', '', html, flags=re.IGNORECASE)
  32. html = html.strip()
  33. return html
  34. def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str,
  35. normalize_numbers: bool = True) -> Tuple[str, List[str], int]:
  36. if not json_res:
  37. return "", [], 0
  38. # 从 table_res_list 中取出 pred_html 转为 Markdown
  39. table_list = (json_res or {}).get("table_res_list", []) or []
  40. md_tables = []
  41. changes_count = 0
  42. for idx, t in enumerate(table_list):
  43. html = t.get("pred_html", "")
  44. if not html:
  45. continue
  46. # 2. 标准化 table_res_list 中的HTML表格
  47. elif normalize_numbers:
  48. normalized_html = normalize_markdown_table(html)
  49. if html != normalized_html:
  50. json_res['table_res_list'][idx]['pred_html'] = normalized_html
  51. changes_count += len([1 for o, n in zip(html, normalized_html) if o != n])
  52. md_tables.append(html_table_to_markdown(html))
  53. # 保存 JSON 结果
  54. out_dir = Path(output_dir).resolve()
  55. out_dir.mkdir(parents=True, exist_ok=True)
  56. json_fp = out_dir / f"{base_name}.json"
  57. with open(json_fp, "w", encoding="utf-8") as f:
  58. json.dump(json_res, f, ensure_ascii=False, indent=2)
  59. return json_fp.as_posix(), md_tables, changes_count
  60. def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
  61. normalize_numbers: bool = True) -> str:
  62. """
  63. 将多个Markdown表格分别保存为 base_name_table_{i}.md,返回保存路径列表。
  64. 同时生成一个合并文件 base_name_tables.md。
  65. """
  66. out_dir = Path(output_dir).resolve()
  67. out_dir.mkdir(parents=True, exist_ok=True)
  68. contents = []
  69. for i, md in enumerate(md_tables, 1):
  70. content = normalize_markdown_table(md) if normalize_numbers else md
  71. contents.append(content)
  72. markdown_path = out_dir / f"{base_name}.md"
  73. with open(markdown_path, "w", encoding="utf-8") as f:
  74. for content in contents:
  75. f.write(content + "\n\n")
  76. return markdown_path.as_posix()
def process_images_with_table_pipeline(
    image_paths: List[str],
    pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
    device: str = "gpu:0",
    output_dir: str = "./output",
    normalize_numbers: bool = True
) -> List[Dict[str, Any]]:
    """
    Run the table_recognition_v2 pipeline over a list of images.

    For each image this saves the pipeline's own outputs (via res.save_all),
    a structured JSON result, and a Markdown file containing every detected
    table (HTML stripped of <html>/<body> wrappers).

    Args:
        image_paths: paths of the images to process.
        pipeline_cfg: pipeline name or YAML config path for create_pipeline.
        device: inference device, e.g. "gpu:0" or "cpu".
        output_dir: directory for all per-image outputs (created if needed).
        normalize_numbers: forwarded to save_json_tables / save_markdown_tables.

    Returns:
        One status dict per image; failed images carry success=False and an
        "error" message. Returns [] if the pipeline fails to initialize.
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)
    print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
    try:
        # Suppress UserWarnings raised by subprocesses the pipeline may spawn.
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        pipeline = create_pipeline(pipeline_cfg, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []
    results_all: List[Dict[str, Any]] = []
    total = len(image_paths)
    print(f"Processing {total} images with table_recognition_v2")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
    with tqdm(total=total, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        for img_path in image_paths:
            start = time.time()
            try:
                outputs = pipeline.predict(
                    img_path,
                    use_doc_orientation_classify=True,
                    use_doc_unwarping=False,
                    use_layout_detection=True,
                    use_ocr_results_with_table_cells=True,
                    use_table_orientation_classify=True,
                    use_wired_table_cells_trans_to_html=True,
                    # Disable per-cell OCR splitting so wrapped (multi-line) cell
                    # text is recognized as one unit. NOTE: requires a patch to
                    # paddlex/inference/pipelines/table_recognition/pipeline_v2.py —
                    # get_table_recognition_res must receive self.cells_split_ocr=False
                    # so line breaks inside a cell are not split apart.
                    use_table_cells_split_ocr=False,
                )
                cost = time.time() - start
                # Each image is expected to yield exactly one result.
                for idx, res in enumerate(outputs):
                    if idx > 0:
                        raise ValueError("Multiple results found for a single image")
                    input_path = Path(res["input_path"])
                    base_name = input_path.stem
                    res.save_all(save_path=output_path.as_posix())  # save all pipeline outputs to output_dir
                    # Save the structured JSON (some versions nest the payload under "res").
                    json_res = res.json.get("res", res.json)
                    saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
                    saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                    normalize_numbers=normalize_numbers)
                    results_all.append({
                        "image_path": str(input_path),
                        "success": True,
                        "time_sec": cost,
                        "device": device,
                        "json_path": saved_json,
                        "markdown_path": saved_md,
                        "tables_detected": len(md_tables),
                        # PDF pages are pre-rendered images named "*_page_N.*".
                        "is_pdf_page": "_page_" in input_path.name,
                        "normalize_numbers": normalize_numbers,
                        "changes_applied": changes_count > 0,
                        "character_changes_count": changes_count,
                    })
                    pbar.update(1)
                    ok = sum(1 for r in results_all if r.get("success"))
                    pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
            except Exception as e:
                traceback.print_exc()
                results_all.append({
                    "image_path": str(img_path),
                    "success": False,
                    "time_sec": 0,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)
                pbar.set_postfix_str("error")
    return results_all
def main() -> int:
    """CLI entry point: parse arguments, run the pipeline, and write summaries.

    Returns 0 on success and 1 when no input files could be collected
    (the value is passed to sys.exit by the __main__ guard).
    """
    parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
    # Exactly one input source must be supplied.
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)")
    g.add_argument("--input_dir", type=str, help="目录")
    g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)")
    g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml",
                        help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)")
    parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI")
    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
    parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
    parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")
    args = parser.parse_args()
    normalize_numbers = not args.no_normalize
    # Reuse the input-collection logic from ppstructurev3_utils.
    input_files = get_input_files(args)
    if not input_files:
        print("❌ No input files found or processed")
        return 1
    if args.test_mode:
        # Test mode caps the workload at the first 20 files.
        input_files = input_files[:20]
        print(f"Test mode: processing only {len(input_files)} images")
    print(f"Using device: {args.device}")
    start = time.time()
    results = process_images_with_table_pipeline(
        input_files,
        args.pipeline,
        args.device,
        args.output_dir,
        normalize_numbers=normalize_numbers
    )
    total_time = time.time() - start
    success = sum(1 for r in results if r.get("success"))
    failed = len(results) - success
    pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False))
    total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success"))
    print("\n" + "="*60)
    print("✅ Processing completed!")
    print("📊 Statistics:")
    print(f" Total files processed: {len(input_files)}")
    print(f" PDF pages processed: {pdf_pages}")
    print(f" Regular images processed: {len(input_files) - pdf_pages}")
    print(f" Successful: {success}")
    print(f" Failed: {failed}")
    print("⏱️ Performance:")
    print(f" Total time: {total_time:.2f} seconds")
    if total_time > 0:
        print(f" Throughput: {len(input_files) / total_time:.2f} files/second")
        print(f" Avg time per file: {total_time / len(input_files):.2f} seconds")
    print(f" Tables detected (total): {total_tables}")
    # Save the aggregated summary JSON next to the per-image outputs.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    summary = {
        "stats": {
            "total_files": len(input_files),
            "pdf_pages": pdf_pages,
            "regular_images": len(input_files) - pdf_pages,
            "success_count": success,
            "error_count": failed,
            "total_time_sec": total_time,
            "throughput_fps": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0,
            "pipeline": args.pipeline,
            "device": args.device,
            "normalize_numbers": normalize_numbers,
            "total_tables": total_tables,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        },
        "results": results
    }
    out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"💾 Results saved to: {out_json}")
    # Optional: collect the per-file processing status into a CSV (best-effort).
    try:
        if args.collect_results:
            from utils import collect_pid_files
            processed_files = collect_pid_files(out_json.as_posix())
            csv_path = Path(args.collect_results).resolve()
            with open(csv_path, "w", encoding="utf-8") as f:
                f.write("image_path,status\n")
                for file_path, status in processed_files:
                    f.write(f"{file_path},{status}\n")
            print(f"💾 Processed files saved to: {csv_path}")
    except Exception as e:
        print(f"⚠️ Failed to save processed files CSV: {e}")
    return 0
  253. if __name__ == "__main__":
  254. print("🚀 启动 table_recognition_v2 单管线处理程序...")
  255. if len(sys.argv) == 1:
  256. # 演示默认参数(请按需修改)
  257. # demo = {
  258. # "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
  259. # "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
  260. # "--pipeline": "./my_config/table_recognition_v2.yaml",
  261. # "--device": "cpu",
  262. # }
  263. demo = {
  264. "--input_file": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
  265. "--output_dir": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/table_recognition_v2_Results",
  266. "--pipeline": "./my_config/table_recognition_v2.yaml",
  267. "--device": "cpu",
  268. }
  269. sys.argv = [sys.argv[0]] + [kv for kv in sum(demo.items(), ())]
  270. sys.exit(main())