"""仅运行 table_recognition_v2 管线,并将表格HTML转为Markdown保存""" import os import re import sys import json import time import argparse import traceback import warnings from pathlib import Path from typing import List, Dict, Any, Tuple, Optional warnings.filterwarnings("ignore", message="To copy construct from a tensor") warnings.filterwarnings("ignore", message="Setting `pad_token_id`") warnings.filterwarnings("ignore", category=UserWarning, module="paddlex") from paddlex import create_pipeline from tqdm import tqdm from dotenv import load_dotenv load_dotenv(override=True) # 复用你现有的输入获取与保存工具 from ppstructurev3_utils import ( save_output_images, # 支持保存 result.img 中的可视化 ) from utils import normalize_markdown_table, get_input_files def html_table_to_markdown(html: str) -> str: """ 将简单HTML表格转换为Markdown表格。 支持thead/tbody/tr/td/th;对嵌套复杂标签仅提取纯文本。 """ #去掉
,以及结尾的 html = re.sub(r'?html>', '', html, flags=re.IGNORECASE) html = re.sub(r'?body>', '', html, flags=re.IGNORECASE) html = html.strip() return html def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str, normalize_numbers: bool = True) -> Tuple[str, List[str], int]: if not json_res: return "", [], 0 # 从 table_res_list 中取出 pred_html 转为 Markdown table_list = (json_res or {}).get("table_res_list", []) or [] md_tables = [] changes_count = 0 for idx, t in enumerate(table_list): html = t.get("pred_html", "") if not html: continue # 2. 标准化 table_res_list 中的HTML表格 elif normalize_numbers: normalized_html = normalize_markdown_table(html) if html != normalized_html: json_res['table_res_list'][idx]['pred_html'] = normalized_html changes_count += len([1 for o, n in zip(html, normalized_html) if o != n]) md_tables.append(html_table_to_markdown(html)) # 保存 JSON 结果 out_dir = Path(output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) json_fp = out_dir / f"{base_name}.json" with open(json_fp, "w", encoding="utf-8") as f: json.dump(json_res, f, ensure_ascii=False, indent=2) return json_fp.as_posix(), md_tables, changes_count def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str, normalize_numbers: bool = True) -> str: """ 将多个Markdown表格分别保存为 base_name_table_{i}.md,返回保存路径列表。 同时生成一个合并文件 base_name_tables.md。 """ out_dir = Path(output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) contents = [] for i, md in enumerate(md_tables, 1): content = normalize_markdown_table(md) if normalize_numbers else md contents.append(content) markdown_path = out_dir / f"{base_name}.md" with open(markdown_path, "w", encoding="utf-8") as f: for content in contents: f.write(content + "\n\n") return markdown_path.as_posix() def process_images_with_table_pipeline( image_paths: List[str], pipeline_cfg: str = "./my_config/table_recognition_v2.yaml", device: str = "gpu:0", output_dir: str = "./output", normalize_numbers: bool = True ) -> List[Dict[str, Any]]: """ 运行 table_recognition_v2 管线,输出 JSON、可视化图,且将每个表格HTML转为Markdown保存。 """ output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...") try: os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning' pipeline = create_pipeline(pipeline_cfg, device=device) print(f"Pipeline initialized successfully on {device}") except Exception as e: print(f"Failed to initialize pipeline: {e}", file=sys.stderr) traceback.print_exc() return [] results_all: List[Dict[str, Any]] = [] total = len(image_paths) print(f"Processing {total} images with table_recognition_v2") print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}") with tqdm(total=total, desc="Processing images", unit="img", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar: for img_path in image_paths: start = time.time() try: outputs = pipeline.predict( img_path, use_doc_orientation_classify=False, use_doc_unwarping=False, use_layout_detection=True, use_ocr_results_with_table_cells=True, use_table_orientation_classify=False, use_wired_table_cells_trans_to_html=True, # 新增:关闭单元格内拆分,整格识别以保留折行文本, # 修改paddlex/inference/pipelines/table_recognition/pipeline_v2.py # get_table_recognition_res传入参数self.cells_split_ocr=False,保证单元格内换行不被拆分 use_table_cells_split_ocr=False, ) cost = time.time() - start # 一般每张图片只返回一个结果 for idx, res in enumerate(outputs): if idx > 0: raise ValueError("Multiple results found for a single image") input_path = Path(res["input_path"]) base_name = input_path.stem res.save_all(save_path=output_path.as_posix()) # 保存所有结果到指定路径 # 保存结构化JSON json_res = res.json.get("res", res.json) saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers) saved_md = save_markdown_tables(md_tables, str(output_path), base_name, normalize_numbers=normalize_numbers) results_all.append({ "image_path": str(input_path), "success": True, "time_sec": cost, "device": device, "json_path": saved_json, "markdown_path": saved_md, "tables_detected": len(md_tables), "is_pdf_page": "_page_" in input_path.name, "normalize_numbers": normalize_numbers, "changes_applied": changes_count > 0, "character_changes_count": changes_count, }) pbar.update(1) ok = sum(1 for r in results_all if r.get("success")) pbar.set_postfix(time=f"{cost:.2f}s", ok=ok) except Exception as e: traceback.print_exc() results_all.append({ "image_path": str(img_path), "success": False, "time_sec": 0, "device": device, "error": str(e) }) pbar.update(1) pbar.set_postfix_str("error") return results_all def main(): parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)") g = parser.add_mutually_exclusive_group(required=True) g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)") g.add_argument("--input_dir", type=str, help="目录") g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)") g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列") parser.add_argument("--output_dir", type=str, required=True, help="输出目录") parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml", help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)") parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu") parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI") parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)") parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件") parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV") args = parser.parse_args() normalize_numbers = not args.no_normalize # 复用 ppstructurev3_utils 的输入收集逻辑 input_files = get_input_files(args) if not input_files: print("❌ No input files found or processed") return 1 if args.test_mode: input_files = input_files[:20] print(f"Test mode: processing only {len(input_files)} images") print(f"Using device: {args.device}") start = time.time() results = process_images_with_table_pipeline( input_files, args.pipeline, args.device, args.output_dir, normalize_numbers=normalize_numbers ) total_time = time.time() - start success = sum(1 for r in results if r.get("success")) failed = len(results) - success pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False)) total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success")) print("\n" + "="*60) print("✅ Processing completed!") print("📊 Statistics:") print(f" Total files processed: {len(input_files)}") print(f" PDF pages processed: {pdf_pages}") print(f" Regular images processed: {len(input_files) - pdf_pages}") print(f" Successful: {success}") print(f" Failed: {failed}") print("⏱️ Performance:") print(f" Total time: {total_time:.2f} seconds") if total_time > 0: print(f" Throughput: {len(input_files) / total_time:.2f} files/second") print(f" Avg time per file: {total_time / len(input_files):.2f} seconds") print(f" Tables detected (total): {total_tables}") # 汇总保存 out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) summary = { "stats": { "total_files": len(input_files), "pdf_pages": pdf_pages, "regular_images": len(input_files) - pdf_pages, "success_count": success, "error_count": failed, "total_time_sec": total_time, "throughput_fps": len(input_files) / total_time if total_time > 0 else 0, "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0, "pipeline": args.pipeline, "device": args.device, "normalize_numbers": normalize_numbers, "total_tables": total_tables, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), }, "results": results } out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json" with open(out_json, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print(f"💾 Results saved to: {out_json}") # 处理状态汇总CSV(可选) try: if args.collect_results: from utils import collect_pid_files processed_files = collect_pid_files(out_json.as_posix()) csv_path = Path(args.collect_results).resolve() with open(csv_path, "w", encoding="utf-8") as f: f.write("image_path,status\n") for file_path, status in processed_files: f.write(f"{file_path},{status}\n") print(f"💾 Processed files saved to: {csv_path}") except Exception as e: print(f"⚠️ Failed to save processed files CSV: {e}") return 0 if __name__ == "__main__": print("🚀 启动 table_recognition_v2 单管线处理程序...") if len(sys.argv) == 1: # 演示默认参数(请按需修改) # demo = { # "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img", # "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results", # "--pipeline": "./my_config/table_recognition_v2.yaml", # "--device": "cpu", # } demo = { "--input_file": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_004.png", "--output_dir": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/table_recognition_v2_Results", "--pipeline": "./my_config/table_recognition_v2.yaml", "--device": "cpu", } sys.argv = [sys.argv[0]] + [kv for kv in sum(demo.items(), ())] sys.exit(main())