| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- """仅运行 table_recognition_v2 管线,并将表格HTML转为Markdown保存"""
- import os
- import re
- import sys
- import json
- import time
- import argparse
- import traceback
- import warnings
- from pathlib import Path
- from typing import List, Dict, Any, Tuple, Optional
- warnings.filterwarnings("ignore", message="To copy construct from a tensor")
- warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
- warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
- from paddlex import create_pipeline
- from tqdm import tqdm
- from dotenv import load_dotenv
- load_dotenv(override=True)
- # 复用你现有的输入获取与保存工具
- from ppstructurev3_utils import (
- get_input_files,
- save_output_images, # 支持保存 result.img 中的可视化
- )
- from utils import normalize_markdown_table
def html_table_to_markdown(html: str) -> str:
    """Strip the ``<html>``/``<body>`` wrapper tags from a table HTML snippet.

    Despite its name, this function performs no HTML-to-Markdown conversion:
    the ``<table>`` markup itself is returned unchanged. Only the outer
    document wrapper tags are removed and surrounding whitespace is trimmed.

    Args:
        html: HTML snippet, e.g. ``<html><body><table>...</table></body></html>``.

    Returns:
        The snippet without opening/closing ``<html>`` and ``<body>`` tags
        (matched case-insensitively, with or without attributes). An empty
        or ``None`` input yields ``""``.
    """
    if not html:
        return ""

    # Remove the two wrapper tags one after the other. The optional
    # "(?:\s[^>]*)?" group lets tags that carry attributes
    # (e.g. <html lang="en">) be stripped too, while inner tags such as
    # <tbody> are not affected because the name must directly follow "<" or "</".
    for tag in ("html", "body"):
        html = re.sub(rf'</?{tag}(?:\s[^>]*)?>', '', html, flags=re.IGNORECASE)

    return html.strip()
def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str,
                     normalize_numbers: bool = True) -> Tuple[str, List[str], int]:
    """Save a pipeline JSON result and collect the HTML of every recognized table.

    Each entry of ``json_res["table_res_list"]`` that has a non-empty
    ``pred_html`` is optionally passed through ``normalize_markdown_table``;
    when that changes the text, the entry is updated in place inside
    ``json_res`` before the JSON file is written.

    Args:
        json_res: Result dictionary of one image. It is mutated in place when
            normalization changes a table's ``pred_html``.
        output_dir: Directory for the JSON file; created if missing.
        base_name: File name (without extension) of the JSON file.
        normalize_numbers: Whether to run ``normalize_markdown_table`` on
            every table's HTML.

    Returns:
        A tuple ``(json_path, md_tables, changes_count)``:

        * ``json_path`` - POSIX path of the written ``<base_name>.json``.
        * ``md_tables`` - one snippet per table, produced by
          ``html_table_to_markdown`` from the *pre-normalization* HTML.
        * ``changes_count`` - number of characters that differ between the
          original and normalized HTML, summed over all tables.

        ``("", [], 0)`` is returned, and nothing is written, when
        ``json_res`` is empty or ``None``.
    """
    if not json_res:
        return "", [], 0

    table_list = json_res.get("table_res_list") or []
    md_tables: List[str] = []
    changes_count = 0

    for table in table_list:
        html = table.get("pred_html", "")
        if not html:
            continue

        if normalize_numbers:
            normalized_html = normalize_markdown_table(html)
            if html != normalized_html:
                # `table` is the very object stored in json_res, so this
                # updates the structure that is dumped to disk below.
                table["pred_html"] = normalized_html
                # Position-wise differences over the common prefix length...
                changes_count += sum(o != n for o, n in zip(html, normalized_html))
                # ...plus the characters zip() drops when the lengths differ,
                # so a change in length alone is never reported as 0 changes.
                changes_count += abs(len(html) - len(normalized_html))

        md_tables.append(html_table_to_markdown(html))

    # Persist the (possibly normalized) JSON result.
    out_dir = Path(output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    json_fp = out_dir / f"{base_name}.json"
    with open(json_fp, "w", encoding="utf-8") as f:
        json.dump(json_res, f, ensure_ascii=False, indent=2)

    return json_fp.as_posix(), md_tables, changes_count
def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
                         normalize_numbers: bool = True) -> str:
    """Write all table snippets of one input into a single ``<base_name>.md`` file.

    The snippets are written in order, each followed by a blank line. No
    per-table files are produced.

    Args:
        md_tables: Table snippets to write.
        output_dir: Destination directory; created if missing.
        base_name: File name (without extension) of the output file.
        normalize_numbers: When true, each snippet is passed through
            ``normalize_markdown_table`` before being written.

    Returns:
        POSIX path of the written Markdown file. The file is created (empty)
        even when ``md_tables`` is empty.
    """
    out_dir = Path(output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Normalize everything before opening the file, so a failure in the
    # normalizer does not leave a truncated file behind.
    if normalize_numbers:
        contents = [normalize_markdown_table(md) for md in md_tables]
    else:
        contents = list(md_tables)

    markdown_path = out_dir / f"{base_name}.md"
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.writelines(content + "\n\n" for content in contents)

    return markdown_path.as_posix()
def process_images_with_table_pipeline(
    image_paths: List[str],
    pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
    device: str = "gpu:0",
    output_dir: str = "./output",
    normalize_numbers: bool = True
) -> List[Dict[str, Any]]:
    """Run the table_recognition_v2 pipeline over images and save all outputs.

    For every image the pipeline result is saved via ``res.save_all``, the
    structured JSON is written by ``save_json_tables`` and the table HTML is
    written to a Markdown file by ``save_markdown_tables``.

    Args:
        image_paths: Paths of the images to process.
        pipeline_cfg: Pipeline name or path to its YAML configuration.
        device: Device string passed to ``create_pipeline`` (e.g. ``gpu:0``).
        output_dir: Directory receiving all outputs; created if missing.
        normalize_numbers: Forwarded to the JSON/Markdown save helpers.

    Returns:
        One record per image. Successful records contain ``image_path``,
        ``success``, ``time_sec``, ``device``, ``json_path``,
        ``markdown_path``, ``tables_detected``, ``is_pdf_page``,
        ``normalize_numbers``, ``changes_applied`` and
        ``character_changes_count``. Failed records contain ``image_path``,
        ``success``, ``time_sec`` (always 0), ``device`` and ``error``.
        An empty list is returned when the pipeline cannot be initialized.
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
    try:
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        pipeline = create_pipeline(pipeline_cfg, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    results_all: List[Dict[str, Any]] = []
    total = len(image_paths)
    ok_count = 0  # running success counter shown in the progress bar

    print(f"Processing {total} images with table_recognition_v2")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")

    with tqdm(total=total, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        for img_path in image_paths:
            start = time.time()
            try:
                # PaddleX documents predict() as returning a generator, so the
                # results are materialized here; otherwise the timer below
                # would stop before any inference had actually run.
                outputs = list(pipeline.predict(
                    img_path,
                    use_doc_preprocessor=True,
                    use_layout_detection=True,
                    use_ocr_model=True
                ))
                cost = time.time() - start

                # Exactly one result is expected per image. Validate before
                # saving anything so an image never ends up with both a
                # success and a failure record, and never vanishes silently.
                if len(outputs) > 1:
                    raise ValueError("Multiple results found for a single image")
                if not outputs:
                    raise ValueError("No result returned for image")
                res = outputs[0]

                input_path = Path(res["input_path"])
                base_name = input_path.stem

                # Save everything the pipeline result can export itself.
                res.save_all(save_path=output_path.as_posix())

                # Structured JSON; fall back to the whole payload when it has
                # no "res" key.
                json_res = res.json.get("res", res.json)
                saved_json, md_tables, changes_count = save_json_tables(
                    json_res, str(output_path), base_name,
                    normalize_numbers=normalize_numbers)

                saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                normalize_numbers=normalize_numbers)

                results_all.append({
                    "image_path": str(input_path),
                    "success": True,
                    "time_sec": cost,
                    "device": device,
                    "json_path": saved_json,
                    "markdown_path": saved_md,
                    "tables_detected": len(md_tables),
                    "is_pdf_page": "_page_" in input_path.name,
                    "normalize_numbers": normalize_numbers,
                    "changes_applied": changes_count > 0,
                    "character_changes_count": changes_count,
                })
                ok_count += 1

                pbar.update(1)
                pbar.set_postfix(time=f"{cost:.2f}s", ok=ok_count)
            except Exception as e:
                # Per-image failures are recorded and processing continues.
                traceback.print_exc()
                results_all.append({
                    "image_path": str(img_path),
                    "success": False,
                    "time_sec": 0,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)
                pbar.set_postfix_str("error")

    return results_all
def main() -> int:
    """Command-line entry point for the table_recognition_v2 runner.

    Parses the arguments, collects input files through ``get_input_files``,
    runs ``process_images_with_table_pipeline``, prints statistics, writes a
    summary JSON into the output directory and optionally writes a
    ``image_path,status`` CSV to the path given by ``--collect_results``.

    Returns:
        Process exit code: ``0`` on completion, ``1`` when no input files
        were found or the pipeline produced no results at all.
    """
    parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)")
    g.add_argument("--input_dir", type=str, help="目录")
    g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)")
    g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列")

    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml",
                        help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)")
    parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI")
    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
    parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
    parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")

    args = parser.parse_args()
    normalize_numbers = not args.no_normalize

    # Input collection is delegated to the shared project helper, which
    # receives the whole argparse namespace.
    input_files = get_input_files(args)
    if not input_files:
        print("❌ No input files found or processed")
        return 1
    if args.test_mode:
        input_files = input_files[:20]
        print(f"Test mode: processing only {len(input_files)} images")

    print(f"Using device: {args.device}")
    start = time.time()
    results = process_images_with_table_pipeline(
        input_files,
        args.pipeline,
        args.device,
        args.output_dir,
        normalize_numbers=normalize_numbers
    )
    total_time = time.time() - start

    if not results:
        # The pipeline function returns an empty list when it could not be
        # initialized; report that as a failure instead of "completed".
        print("❌ Pipeline produced no results", file=sys.stderr)
        return 1

    success = sum(1 for r in results if r.get("success"))
    failed = len(results) - success
    pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False))
    total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success"))

    print("\n" + "="*60)
    print("✅ Processing completed!")
    print("📊 Statistics:")
    print(f" Total files processed: {len(input_files)}")
    print(f" PDF pages processed: {pdf_pages}")
    print(f" Regular images processed: {len(input_files) - pdf_pages}")
    print(f" Successful: {success}")
    print(f" Failed: {failed}")
    print("⏱️ Performance:")
    print(f" Total time: {total_time:.2f} seconds")
    if total_time > 0:
        print(f" Throughput: {len(input_files) / total_time:.2f} files/second")
        print(f" Avg time per file: {total_time / len(input_files):.2f} seconds")
    print(f" Tables detected (total): {total_tables}")

    # Write the run summary next to the per-image outputs.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    summary = {
        "stats": {
            "total_files": len(input_files),
            "pdf_pages": pdf_pages,
            "regular_images": len(input_files) - pdf_pages,
            "success_count": success,
            "error_count": failed,
            "total_time_sec": total_time,
            "throughput_fps": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0,
            "pipeline": args.pipeline,
            "device": args.device,
            "normalize_numbers": normalize_numbers,
            "total_tables": total_tables,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        },
        "results": results
    }
    out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"💾 Results saved to: {out_json}")

    # Optional status CSV. This step is best-effort: a failure is reported
    # but does not change the exit code.
    try:
        if args.collect_results:
            import csv
            from utils import collect_pid_files

            processed_files = collect_pid_files(out_json.as_posix())
            csv_path = Path(args.collect_results).resolve()
            # csv.writer quotes fields containing commas, quotes or newlines,
            # which plain string interpolation would silently corrupt.
            with open(csv_path, "w", encoding="utf-8", newline="") as f:
                writer = csv.writer(f, lineterminator="\n")
                writer.writerow(["image_path", "status"])
                writer.writerows(
                    (file_path, status) for file_path, status in processed_files
                )
            print(f"💾 Processed files saved to: {csv_path}")
    except Exception as e:
        print(f"⚠️ Failed to save processed files CSV: {e}")

    return 0
if __name__ == "__main__":
    print("🚀 启动 table_recognition_v2 单管线处理程序...")

    if len(sys.argv) == 1:
        # No command-line arguments were given: fall back to demo defaults
        # (adjust these paths to your environment).
        demo = {
            "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
            "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
            "--pipeline": "./my_config/table_recognition_v2.yaml",
            "--device": "cpu",
        }
        # Flatten {"--flag": "value"} pairs into ["--flag", "value", ...].
        sys.argv = [sys.argv[0]] + [token for pair in demo.items() for token in pair]

    sys.exit(main())
|