@@ -0,0 +1,295 @@
+"""Run only the table_recognition_v2 pipeline, saving each table's HTML converted to Markdown."""
+import os
+import re
+import sys
+import json
+import time
+import argparse
+import traceback
+import warnings
+from pathlib import Path
+from typing import List, Dict, Any, Tuple, Optional
+
+warnings.filterwarnings("ignore", message="To copy construct from a tensor")
+warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
+warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
+
+from paddlex import create_pipeline
+from tqdm import tqdm
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+# Reuse the existing input-collection and saving utilities
+from ppstructurev3_utils import (
+    get_input_files,
+    save_output_images,  # supports saving the visualizations in result.img
+)
+from utils import normalize_markdown_table
+
+def html_table_to_markdown(html: str) -> str:
+    """
+    Prepare a predicted HTML table for embedding in Markdown output.
+
+    Only the <html>/<body> wrapper tags are stripped; the <table> markup itself
+    is kept verbatim, since Markdown renderers accept inline HTML tables.
+    """
+    # Drop the leading <html><body> and trailing </body></html> wrappers
+    html = re.sub(r'</?html>', '', html, flags=re.IGNORECASE)
+    html = re.sub(r'</?body>', '', html, flags=re.IGNORECASE)
+    return html.strip()
+
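+# Optional sketch: a genuine HTML -> Markdown pipe-table converter, for cases
+# where downstream tooling cannot render inline HTML tables. This helper is
+# illustrative only and not wired into the pipeline below; it assumes simple
+# tables without rowspan/colspan and keeps only the plain text of each cell.
+def html_table_to_pipe_table(html: str) -> str:
+    rows = []
+    for row_html in re.findall(r'<tr[^>]*>(.*?)</tr>', html, flags=re.IGNORECASE | re.DOTALL):
+        cells = re.findall(r'<t[dh][^>]*>(.*?)</t[dh]>', row_html, flags=re.IGNORECASE | re.DOTALL)
+        # Strip nested tags and escape pipes so cell text stays in its column
+        rows.append([re.sub(r'<[^>]+>', '', c).strip().replace('|', '\\|') for c in cells])
+    if not rows:
+        return ""
+    width = max(len(r) for r in rows)
+    rows = [r + [''] * (width - len(r)) for r in rows]  # pad ragged rows
+    lines = ['| ' + ' | '.join(rows[0]) + ' |', '|' + ' --- |' * width]
+    lines += ['| ' + ' | '.join(r) + ' |' for r in rows[1:]]
+    return '\n'.join(lines)
+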
+def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str
+                     ) -> Tuple[str, List[str]]:
+    """Normalize table HTML in json_res, save it as JSON, and return (json_path, markdown_tables)."""
+    if not json_res:
+        return "", []
+
+    # Pull pred_html out of table_res_list and convert each table to Markdown
+    table_list = (json_res or {}).get("table_res_list", []) or []
+    md_tables = []
+    for idx, t in enumerate(table_list):
+        html = t.get("pred_html", "")
+        if not html:
+            continue
+        # Normalize the HTML table before writing it back into the JSON
+        normalized_html = normalize_markdown_table(html)
+        if html != normalized_html:
+            json_res['table_res_list'][idx]['pred_html'] = normalized_html
+        md_tables.append(html_table_to_markdown(normalized_html))
+
+    # Save the JSON result
+    out_dir = Path(output_dir).resolve()
+    out_dir.mkdir(parents=True, exist_ok=True)
+    json_fp = out_dir / f"{base_name}.json"
+    with open(json_fp, "w", encoding="utf-8") as f:
+        json.dump(json_res, f, ensure_ascii=False, indent=2)
+
+    return json_fp.as_posix(), md_tables
+
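+# For reference, the json_res consumed above is expected to look roughly like
+# this (abridged; only the key this script reads is shown, values illustrative):
+# {
+#     "table_res_list": [
+#         {"pred_html": "<html><body><table>...</table></body></html>", ...}
+#     ],
+#     ...
+# }
+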
+def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
+                         normalize_numbers: bool = True) -> str:
+    """
+    Write all Markdown tables into a single combined file {base_name}.md
+    (one table per block) and return its path. Numbers are normalized via
+    normalize_markdown_table when normalize_numbers is True.
+    """
+    out_dir = Path(output_dir).resolve()
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    contents = []
+    for md in md_tables:
+        content = normalize_markdown_table(md) if normalize_numbers else md
+        contents.append(content)
+
+    markdown_path = out_dir / f"{base_name}.md"
+    with open(markdown_path, "w", encoding="utf-8") as f:
+        for content in contents:
+            f.write(content + "\n\n")
+
+    return markdown_path.as_posix()
+
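+# Note: since html_table_to_markdown keeps the <table> markup, the combined
+# {base_name}.md holds one HTML table per block, e.g. (illustrative):
+#
+#   <table><tr><td>2024-01-01</td><td>100.00</td></tr></table>
+#
+#   <table>...</table>
+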
+def process_images_with_table_pipeline(
+    image_paths: List[str],
+    pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
+    device: str = "gpu:0",
+    output_dir: str = "./output",
+    normalize_numbers: bool = True
+) -> List[Dict[str, Any]]:
+    """
+    Run the table_recognition_v2 pipeline: write the JSON and visualization
+    outputs, and save every table's HTML converted to Markdown.
+    """
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
+    try:
+        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
+        pipeline = create_pipeline(pipeline_cfg, device=device)
+        print(f"Pipeline initialized successfully on {device}")
+    except Exception as e:
+        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return []
+
+    results_all: List[Dict[str, Any]] = []
+    total = len(image_paths)
+    print(f"Processing {total} images with table_recognition_v2")
+    print(f"🔧 Number normalization: {'enabled' if normalize_numbers else 'disabled'}")
+
+    with tqdm(total=total, desc="Processing images", unit="img",
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        for img_path in image_paths:
+            img_path = str(img_path)
+            start = time.time()
+            try:
+                outputs = pipeline.predict(
+                    img_path,
+                    use_doc_preprocessor=True,
+                    use_layout_detection=True,
+                    use_ocr_model=True
+                )
+                cost = time.time() - start
+
+                # Each image is expected to yield exactly one result
+                for idx, res in enumerate(outputs):
+                    if idx > 0:
+                        raise ValueError("Multiple results found for a single image")
+
+                    input_path = Path(res["input_path"])
+                    base_name = input_path.stem
+
+                    res.save_all(save_path=output_path.as_posix())  # save all outputs to the target path
+                    # Save the structured JSON
+                    json_res = res.json.get("res", res.json)
+
+                    saved_json, md_tables = save_json_tables(json_res, str(output_path), base_name)
+
+                    saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
+                                                    normalize_numbers=normalize_numbers)
+
+                    results_all.append({
+                        "image_path": str(input_path),
+                        "success": True,
+                        "time_sec": cost,
+                        "device": device,
+                        "json_path": saved_json,
+                        "markdown_path": saved_md,
+                        "tables_detected": len(md_tables),
+                        "is_pdf_page": "_page_" in input_path.name
+                    })
+
+                pbar.update(1)
+                ok = sum(1 for r in results_all if r.get("success"))
+                pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
+
+            except Exception as e:
+                traceback.print_exc()
+                results_all.append({
+                    "image_path": str(img_path),
+                    "success": False,
+                    "time_sec": 0,
+                    "device": device,
+                    "error": str(e)
+                })
+                pbar.update(1)
+                pbar.set_postfix_str("error")
+
+    return results_all
+
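+# Minimal usage sketch (paths are placeholders):
+#
+#   results = process_images_with_table_pipeline(
+#       ["./samples/page_1.png"],
+#       pipeline_cfg="./my_config/table_recognition_v2.yaml",
+#       device="cpu",
+#       output_dir="./output",
+#   )
+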
+def main():
+    parser = argparse.ArgumentParser(description="Run the table_recognition_v2 pipeline only (outputs Markdown tables)")
+    g = parser.add_mutually_exclusive_group(required=True)
+    g.add_argument("--input_file", type=str, help="A single file (image or PDF)")
+    g.add_argument("--input_dir", type=str, help="A directory of inputs")
+    g.add_argument("--input_file_list", type=str, help="A list file (one path per line)")
+    g.add_argument("--input_csv", type=str, help="A CSV with image_path and status columns")
+
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml",
+                        help="Pipeline name or config file path (defaults to this repo's table_recognition_v2.yaml)")
+    parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 or cpu")
+    parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF-to-image conversion")
+    parser.add_argument("--no-normalize", action="store_true", help="Disable number normalization (affects Markdown content only)")
+    parser.add_argument("--test_mode", action="store_true", help="Process only the first 20 files")
+    parser.add_argument("--collect_results", type=str, help="Collect processing status into the given CSV")
+
+    args = parser.parse_args()
+    normalize_numbers = not args.no_normalize
+
+    # Reuse the input-collection logic from ppstructurev3_utils
+    input_files = get_input_files(args)
+    if not input_files:
+        print("❌ No input files found or processed")
+        return 1
+    if args.test_mode:
+        input_files = input_files[:20]
+        print(f"Test mode: processing only {len(input_files)} images")
+
+    print(f"Using device: {args.device}")
+    start = time.time()
+    results = process_images_with_table_pipeline(
+        input_files,
+        args.pipeline,
+        args.device,
+        args.output_dir,
+        normalize_numbers=normalize_numbers
+    )
+    total_time = time.time() - start
+
+    success = sum(1 for r in results if r.get("success"))
+    failed = len(results) - success
+    pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False))
+    total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success"))
+
+    print("\n" + "="*60)
+    print("✅ Processing completed!")
+    print("📊 Statistics:")
+    print(f"   Total files processed: {len(input_files)}")
+    print(f"   PDF pages processed: {pdf_pages}")
+    print(f"   Regular images processed: {len(input_files) - pdf_pages}")
+    print(f"   Successful: {success}")
+    print(f"   Failed: {failed}")
+    print("⏱️ Performance:")
+    print(f"   Total time: {total_time:.2f} seconds")
+    if total_time > 0:
+        print(f"   Throughput: {len(input_files) / total_time:.2f} files/second")
+        print(f"   Avg time per file: {total_time / len(input_files):.2f} seconds")
+    print(f"   Tables detected (total): {total_tables}")
+
+    # Save the run summary
+    out_dir = Path(args.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    summary = {
+        "stats": {
+            "total_files": len(input_files),
+            "pdf_pages": pdf_pages,
+            "regular_images": len(input_files) - pdf_pages,
+            "success_count": success,
+            "error_count": failed,
+            "total_time_sec": total_time,
+            "throughput_fps": len(input_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0,
+            "pipeline": args.pipeline,
+            "device": args.device,
+            "normalize_numbers": normalize_numbers,
+            "total_tables": total_tables,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+        },
+        "results": results
+    }
+    out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json"
+    with open(out_json, "w", encoding="utf-8") as f:
+        json.dump(summary, f, ensure_ascii=False, indent=2)
+    print(f"💾 Results saved to: {out_json}")
+
+    # Optional: collect processing status into a CSV
+    try:
+        if args.collect_results:
+            from utils import collect_pid_files
+            processed_files = collect_pid_files(out_json.as_posix())
+            csv_path = Path(args.collect_results).resolve()
+            with open(csv_path, "w", encoding="utf-8") as f:
+                f.write("image_path,status\n")
+                for file_path, status in processed_files:
+                    f.write(f"{file_path},{status}\n")
+            print(f"💾 Processed files saved to: {csv_path}")
+    except Exception as e:
+        print(f"⚠️ Failed to save processed files CSV: {e}")
+
+    return 0
+
+if __name__ == "__main__":
+    print("🚀 Starting the table_recognition_v2 single-pipeline runner...")
+    if len(sys.argv) == 1:
+        # Demo default arguments (adjust as needed)
+        demo = {
+            "--input_file": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
+            "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
+            "--pipeline": "./my_config/table_recognition_v2.yaml",
+            "--device": "cpu",
+        }
+        # Flatten the {flag: value} pairs into argv order
+        sys.argv = [sys.argv[0]] + list(sum(demo.items(), ()))
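+        # Equivalent CLI invocation (illustrative; the script filename is a placeholder):
+        #   python run_table_v2.py --input_file <image-or-pdf> --output_dir <dir> \
+        #       --pipeline ./my_config/table_recognition_v2.yaml --device cpu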
+
+    sys.exit(main())
|