table_recognition_v2_single_process.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
"""Run only the table_recognition_v2 pipeline and save each table's HTML converted to Markdown."""
import os
import re
import sys
import json
import time
import argparse
import traceback
import warnings
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

# Silence known-noisy warnings emitted at import time by torch / transformers / paddlex.
warnings.filterwarnings("ignore", message="To copy construct from a tensor")
warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")

from paddlex import create_pipeline
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv(override=True)

# Reuse the existing input-collection and saving helpers.
from ppstructurev3_utils import (
    get_input_files,
    save_output_images,  # supports saving the visualizations held in result.img
)
from utils import normalize_markdown_table
  25. def html_table_to_markdown(html: str) -> str:
  26. """
  27. 将简单HTML表格转换为Markdown表格。
  28. 支持thead/tbody/tr/td/th;对嵌套复杂标签仅提取纯文本。
  29. """
  30. #去掉<html><body>,以及结尾的</body></html>
  31. html = re.sub(r'</?html>', '', html, flags=re.IGNORECASE)
  32. html = re.sub(r'</?body>', '', html, flags=re.IGNORECASE)
  33. html = html.strip()
  34. return html
  35. def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str,
  36. normalize_numbers: bool = True) -> Tuple[str, List[str], int]:
  37. if not json_res:
  38. return "", [], 0
  39. # 从 table_res_list 中取出 pred_html 转为 Markdown
  40. table_list = (json_res or {}).get("table_res_list", []) or []
  41. md_tables = []
  42. changes_count = 0
  43. for idx, t in enumerate(table_list):
  44. html = t.get("pred_html", "")
  45. if not html:
  46. continue
  47. # 2. 标准化 table_res_list 中的HTML表格
  48. elif normalize_numbers:
  49. normalized_html = normalize_markdown_table(html)
  50. if html != normalized_html:
  51. json_res['table_res_list'][idx]['pred_html'] = normalized_html
  52. changes_count += len([1 for o, n in zip(html, normalized_html) if o != n])
  53. md_tables.append(html_table_to_markdown(html))
  54. # 保存 JSON 结果
  55. out_dir = Path(output_dir).resolve()
  56. out_dir.mkdir(parents=True, exist_ok=True)
  57. json_fp = out_dir / f"{base_name}.json"
  58. with open(json_fp, "w", encoding="utf-8") as f:
  59. json.dump(json_res, f, ensure_ascii=False, indent=2)
  60. return json_fp.as_posix(), md_tables, changes_count
  61. def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
  62. normalize_numbers: bool = True) -> str:
  63. """
  64. 将多个Markdown表格分别保存为 base_name_table_{i}.md,返回保存路径列表。
  65. 同时生成一个合并文件 base_name_tables.md。
  66. """
  67. out_dir = Path(output_dir).resolve()
  68. out_dir.mkdir(parents=True, exist_ok=True)
  69. contents = []
  70. for i, md in enumerate(md_tables, 1):
  71. content = normalize_markdown_table(md) if normalize_numbers else md
  72. contents.append(content)
  73. markdown_path = out_dir / f"{base_name}.md"
  74. with open(markdown_path, "w", encoding="utf-8") as f:
  75. for content in contents:
  76. f.write(content + "\n\n")
  77. return markdown_path.as_posix()
def process_images_with_table_pipeline(
    image_paths: List[str],
    pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
    device: str = "gpu:0",
    output_dir: str = "./output",
    normalize_numbers: bool = True
) -> List[Dict[str, Any]]:
    """Run the table_recognition_v2 pipeline over *image_paths*.

    For each image: saves all pipeline artifacts, writes the structured JSON
    (via save_json_tables), and writes the tables as a Markdown file (via
    save_markdown_tables). Returns one status dict per input image; failures
    are recorded with ``success: False`` instead of aborting the batch.
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
    try:
        # Also suppress UserWarnings in any subprocesses the pipeline spawns.
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        pipeline = create_pipeline(pipeline_cfg, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        # Initialization failure: report and return an empty result list.
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    results_all: List[Dict[str, Any]] = []
    total = len(image_paths)
    print(f"Processing {total} images with table_recognition_v2")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")

    with tqdm(total=total, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        for img_path in image_paths:
            start = time.time()
            try:
                outputs = pipeline.predict(
                    img_path,
                    use_doc_preprocessor=True,
                    use_layout_detection=True,
                    use_ocr_model=True
                )
                cost = time.time() - start
                # Each image is expected to yield exactly one result;
                # a second result is treated as an error for this image.
                for idx, res in enumerate(outputs):
                    if idx > 0:
                        raise ValueError("Multiple results found for a single image")
                    input_path = Path(res["input_path"])
                    base_name = input_path.stem
                    res.save_all(save_path=output_path.as_posix())  # save every result artifact to the output dir
                    # The structured payload may be nested under "res".
                    json_res = res.json.get("res", res.json)
                    saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
                    saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                    normalize_numbers=normalize_numbers)
                    results_all.append({
                        "image_path": str(input_path),
                        "success": True,
                        "time_sec": cost,
                        "device": device,
                        "json_path": saved_json,
                        "markdown_path": saved_md,
                        "tables_detected": len(md_tables),
                        # PDF pages are rendered to images named "..._page_N..." upstream.
                        "is_pdf_page": "_page_" in input_path.name,
                        "normalize_numbers": normalize_numbers,
                        "changes_applied": changes_count > 0,
                        "character_changes_count": changes_count,
                    })
                pbar.update(1)
                ok = sum(1 for r in results_all if r.get("success"))
                pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
            except Exception as e:
                # Record the failure and keep processing the remaining images.
                traceback.print_exc()
                results_all.append({
                    "image_path": str(img_path),
                    "success": False,
                    "time_sec": 0,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)
                pbar.set_postfix_str("error")
    return results_all
def main():
    """CLI entry point: parse arguments, run the pipeline, and write summaries.

    Returns 0 on success, 1 when no input files could be collected.
    """
    parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)")
    g.add_argument("--input_dir", type=str, help="目录")
    g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)")
    g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml",
                        help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)")
    parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI")
    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
    parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
    parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")
    args = parser.parse_args()

    normalize_numbers = not args.no_normalize

    # Reuse the input-collection logic from ppstructurev3_utils.
    input_files = get_input_files(args)
    if not input_files:
        print("❌ No input files found or processed")
        return 1
    if args.test_mode:
        # Quick smoke run: only the first 20 files.
        input_files = input_files[:20]
        print(f"Test mode: processing only {len(input_files)} images")

    print(f"Using device: {args.device}")
    start = time.time()
    results = process_images_with_table_pipeline(
        input_files,
        args.pipeline,
        args.device,
        args.output_dir,
        normalize_numbers=normalize_numbers
    )
    total_time = time.time() - start

    # Aggregate run statistics.
    success = sum(1 for r in results if r.get("success"))
    failed = len(results) - success
    pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False))
    total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success"))

    print("\n" + "="*60)
    print("✅ Processing completed!")
    print("📊 Statistics:")
    print(f" Total files processed: {len(input_files)}")
    print(f" PDF pages processed: {pdf_pages}")
    print(f" Regular images processed: {len(input_files) - pdf_pages}")
    print(f" Successful: {success}")
    print(f" Failed: {failed}")
    print("⏱️ Performance:")
    print(f" Total time: {total_time:.2f} seconds")
    if total_time > 0:
        print(f" Throughput: {len(input_files) / total_time:.2f} files/second")
        print(f" Avg time per file: {total_time / len(input_files):.2f} seconds")
    print(f" Tables detected (total): {total_tables}")

    # Save the run summary as JSON.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    summary = {
        "stats": {
            "total_files": len(input_files),
            "pdf_pages": pdf_pages,
            "regular_images": len(input_files) - pdf_pages,
            "success_count": success,
            "error_count": failed,
            "total_time_sec": total_time,
            "throughput_fps": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0,
            "pipeline": args.pipeline,
            "device": args.device,
            "normalize_numbers": normalize_numbers,
            "total_tables": total_tables,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        },
        "results": results
    }
    out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"💾 Results saved to: {out_json}")

    # Optionally collect per-file processing status into a CSV; failures here
    # are reported but deliberately do not fail the whole run.
    try:
        if args.collect_results:
            from utils import collect_pid_files
            processed_files = collect_pid_files(out_json.as_posix())
            csv_path = Path(args.collect_results).resolve()
            with open(csv_path, "w", encoding="utf-8") as f:
                f.write("image_path,status\n")
                for file_path, status in processed_files:
                    f.write(f"{file_path},{status}\n")
            print(f"💾 Processed files saved to: {csv_path}")
    except Exception as e:
        print(f"⚠️ Failed to save processed files CSV: {e}")
    return 0
if __name__ == "__main__":
    print("🚀 启动 table_recognition_v2 单管线处理程序...")
    if len(sys.argv) == 1:
        # Demo defaults used when no CLI arguments are given (adjust as needed).
        demo = {
            "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
            "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
            "--pipeline": "./my_config/table_recognition_v2.yaml",
            "--device": "cpu",
        }
        # Flatten the {flag: value} mapping into argv order: [flag1, val1, flag2, val2, ...].
        sys.argv = [sys.argv[0]] + [kv for kv in sum(demo.items(), ())]
    sys.exit(main())