"""仅运行 table_recognition_v2 管线,并将表格HTML转为Markdown保存""" import os import re import sys import json import time import argparse import traceback import warnings from pathlib import Path from typing import List, Dict, Any, Tuple, Optional warnings.filterwarnings("ignore", message="To copy construct from a tensor") warnings.filterwarnings("ignore", message="Setting `pad_token_id`") warnings.filterwarnings("ignore", category=UserWarning, module="paddlex") from paddlex import create_pipeline from tqdm import tqdm from dotenv import load_dotenv load_dotenv(override=True) # 复用你现有的输入获取与保存工具 from ppstructurev3_utils import ( save_output_images, # 支持保存 result.img 中的可视化 ) from utils import normalize_markdown_table, get_input_files # 🎯 新增:导入适配器 from adapters.table_recognition_adapter import apply_table_recognition_adapter, restore_original_function def html_table_to_markdown(html: str) -> str: """ 将简单HTML表格转换为Markdown表格。 支持thead/tbody/tr/td/th;对嵌套复杂标签仅提取纯文本。 """ #去掉
,以及结尾的 html = re.sub(r'?html>', '', html, flags=re.IGNORECASE) html = re.sub(r'?body>', '', html, flags=re.IGNORECASE) html = html.strip() return html def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str, normalize_numbers: bool = True) -> Tuple[str, List[str], int]: if not json_res: return "", [], 0 # 从 table_res_list 中取出 pred_html 转为 Markdown table_list = (json_res or {}).get("table_res_list", []) or [] md_tables = [] changes_count = 0 for idx, t in enumerate(table_list): html = t.get("pred_html", "") if not html: continue # 2. 标准化 table_res_list 中的HTML表格 elif normalize_numbers: normalized_html = normalize_markdown_table(html) if html != normalized_html: json_res['table_res_list'][idx]['pred_html'] = normalized_html changes_count += len([1 for o, n in zip(html, normalized_html) if o != n]) md_tables.append(html_table_to_markdown(html)) # 保存 JSON 结果 out_dir = Path(output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) json_fp = out_dir / f"{base_name}.json" with open(json_fp, "w", encoding="utf-8") as f: json.dump(json_res, f, ensure_ascii=False, indent=2) return json_fp.as_posix(), md_tables, changes_count def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str, normalize_numbers: bool = True) -> str: """ 将多个Markdown表格分别保存为 base_name_table_{i}.md,返回保存路径列表。 同时生成一个合并文件 base_name_tables.md。 """ out_dir = Path(output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) contents = [] for i, md in enumerate(md_tables, 1): content = normalize_markdown_table(md) if normalize_numbers else md contents.append(content) markdown_path = out_dir / f"{base_name}.md" with open(markdown_path, "w", encoding="utf-8") as f: for content in contents: f.write(content + "\n\n") return markdown_path.as_posix() def process_images_with_table_pipeline( image_paths: List[str], pipeline_cfg: str = "./my_config/table_recognition_v2.yaml", device: str = "gpu:0", output_dir: str = "./output", normalize_numbers: bool = True, use_enhanced_adapter: bool = True # 🎯 新增参数 ) -> List[Dict[str, Any]]: """ 运行 table_recognition_v2 管线,输出 JSON、可视化图,且将每个表格HTML转为Markdown保存。 """ output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) # 🎯 应用适配器 adapter_applied = False if use_enhanced_adapter: adapter_applied = apply_table_recognition_adapter() if adapter_applied: print("🎯 Enhanced table recognition adapter activated") else: print("⚠️ Failed to apply adapter, using original implementation") print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...") try: os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning' pipeline = create_pipeline(pipeline_cfg, device=device) print(f"Pipeline initialized successfully on {device}") except Exception as e: print(f"Failed to initialize pipeline: {e}", file=sys.stderr) if adapter_applied: restore_original_function() return [] try: results_all: List[Dict[str, Any]] = [] total = len(image_paths) print(f"Processing {total} images with table_recognition_v2") print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}") print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}") with tqdm(total=total, desc="Processing images", unit="img", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar: for img_path in image_paths: start = time.time() try: outputs = pipeline.predict( img_path, use_doc_orientation_classify=True, use_doc_unwarping=False, use_layout_detection=True, use_ocr_results_with_table_cells=True, use_table_orientation_classify=True, use_wired_table_cells_trans_to_html=True, # 🎯 注意:适配器模式下不需要这个参数 # use_table_cells_split_ocr=False, ) cost = time.time() - start # 一般每张图片只返回一个结果 for idx, res in enumerate(outputs): if idx > 0: raise ValueError("Multiple results found for a single image") input_path = Path(res["input_path"]) base_name = input_path.stem res.save_all(save_path=output_path.as_posix()) # 保存所有结果到指定路径 # 保存结构化JSON json_res = res.json.get("res", res.json) saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers) saved_md = save_markdown_tables(md_tables, str(output_path), base_name, normalize_numbers=normalize_numbers) results_all.append({ "image_path": str(input_path), "success": True, "time_sec": cost, "device": device, "json_path": saved_json, "markdown_path": saved_md, "tables_detected": len(md_tables), "is_pdf_page": "_page_" in input_path.name, "normalize_numbers": normalize_numbers, "changes_applied": changes_count > 0, "character_changes_count": changes_count, }) pbar.update(1) ok = sum(1 for r in results_all if r.get("success")) pbar.set_postfix(time=f"{cost:.2f}s", ok=ok) except Exception as e: traceback.print_exc() results_all.append({ "image_path": str(img_path), "success": False, "time_sec": 0, "device": device, "error": str(e) }) pbar.update(1) pbar.set_postfix_str("error") return results_all finally: # 🎯 清理:恢复原始函数 if adapter_applied: restore_original_function() print("🔄 Original function restored") def main(): parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)") g = parser.add_mutually_exclusive_group(required=True) g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)") g.add_argument("--input_dir", type=str, help="目录") g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)") g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列") parser.add_argument("--output_dir", type=str, required=True, help="输出目录") parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml", help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)") parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu") parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI") parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)") parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件") parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV") parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器") args = parser.parse_args() normalize_numbers = not args.no_normalize use_enhanced_adapter = not args.no_adapter # 复用 ppstructurev3_utils 的输入收集逻辑 input_files = get_input_files(args) if not input_files: print("❌ No input files found or processed") return 1 if args.test_mode: input_files = input_files[:20] print(f"Test mode: processing only {len(input_files)} images") print(f"Using device: {args.device}") start = time.time() results = process_images_with_table_pipeline( input_files, args.pipeline, args.device, args.output_dir, normalize_numbers=normalize_numbers, use_enhanced_adapter=use_enhanced_adapter # 🎯 传递参数 ) total_time = time.time() - start success = sum(1 for r in results if r.get("success")) failed = len(results) - success pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False)) total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success")) print("\n" + "="*60) print("✅ Processing completed!") print("📊 Statistics:") print(f" Total files processed: {len(input_files)}") print(f" PDF pages processed: {pdf_pages}") print(f" Regular images processed: {len(input_files) - pdf_pages}") print(f" Successful: {success}") print(f" Failed: {failed}") print("⏱️ Performance:") print(f" Total time: {total_time:.2f} seconds") if total_time > 0: print(f" Throughput: {len(input_files) / total_time:.2f} files/second") print(f" Avg time per file: {total_time / len(input_files):.2f} seconds") print(f" Tables detected (total): {total_tables}") # 汇总保存 out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) summary = { "stats": { "total_files": len(input_files), "pdf_pages": pdf_pages, "regular_images": len(input_files) - pdf_pages, "success_count": success, "error_count": failed, "total_time_sec": total_time, "throughput_fps": len(input_files) / total_time if total_time > 0 else 0, "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0, "pipeline": args.pipeline, "device": args.device, "normalize_numbers": normalize_numbers, "total_tables": total_tables, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), }, "results": results } out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json" with open(out_json, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print(f"💾 Results saved to: {out_json}") # 处理状态汇总CSV(可选) try: if args.collect_results: from utils import collect_pid_files processed_files = collect_pid_files(out_json.as_posix()) csv_path = Path(args.collect_results).resolve() with open(csv_path, "w", encoding="utf-8") as f: f.write("image_path,status\n") for file_path, status in processed_files: f.write(f"{file_path},{status}\n") print(f"💾 Processed files saved to: {csv_path}") except Exception as e: print(f"⚠️ Failed to save processed files CSV: {e}") return 0 if __name__ == "__main__": print("🚀 启动 table_recognition_v2 单管线处理程序...") if len(sys.argv) == 1: # 演示默认参数(请按需修改) demo = { "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img", "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results", "--pipeline": "./my_config/table_recognition_v2.yaml", "--device": "cpu", } # demo = { # # "--input_file": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_006.png", # "--input_file": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_005.png", # "--output_dir": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/table_recognition_v2_Results", # "--pipeline": "./my_config/table_recognition_v2.yaml", # "--device": "cpu", # } sys.argv = [sys.argv[0]] + [kv for kv in sum(demo.items(), ())] sys.exit(main())