Jelajahi Sumber

feat: 新增表格识别管线,支持HTML转Markdown及结果保存

zhch158_admin 1 bulan lalu
induk
melakukan
a5cb6a797a
1 mengubah file dengan 295 tambahan dan 0 penghapusan
  1. 295 0
      zhch/ table_recognition_v2_single_process.py

+ 295 - 0
zhch/ table_recognition_v2_single_process.py

@@ -0,0 +1,295 @@
+"""仅运行 table_recognition_v2 管线,并将表格HTML转为Markdown保存"""
+import os
+import re
+import sys
+import json
+import time
+import argparse
+import traceback
+import warnings
+from pathlib import Path
+from typing import List, Dict, Any, Tuple, Optional
+
+warnings.filterwarnings("ignore", message="To copy construct from a tensor")
+warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
+warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
+
+from paddlex import create_pipeline
+from tqdm import tqdm
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+# 复用你现有的输入获取与保存工具
+from ppstructurev3_utils import (
+    get_input_files,
+    save_output_images,   # 支持保存 result.img 中的可视化
+)
+from utils import normalize_markdown_table
+
def html_table_to_markdown(html: str) -> str:
    """Convert a simple HTML table into a Markdown (pipe) table.

    Supports thead/tbody/tr/td/th; nested tags inside a cell contribute only
    their text content. rowspan/colspan attributes are ignored — cells are
    laid out sequentially per row. The first row becomes the Markdown header
    row.

    Args:
        html: HTML fragment, possibly wrapped in <html><body> ... </body></html>.

    Returns:
        A Markdown table string. If no table cells are found, the input is
        returned with the <html>/<body> wrappers stripped (previous behavior).
    """
    from html.parser import HTMLParser  # stdlib; local import keeps this helper self-contained

    # Strip <html>/<body> wrappers (opening and closing) regardless of case.
    html = re.sub(r'</?(?:html|body)>', '', html, flags=re.IGNORECASE).strip()

    class _CellCollector(HTMLParser):
        """Accumulates table rows as lists of cell-text strings."""

        def __init__(self) -> None:
            super().__init__()
            self.rows: List[List[str]] = []
            self._row: Optional[List[str]] = None
            self._cell: Optional[List[str]] = None

        def handle_starttag(self, tag, attrs):
            if tag == 'tr':
                self._row = []
            elif tag in ('td', 'th'):
                self._cell = []

        def handle_endtag(self, tag):
            if tag == 'tr' and self._row is not None:
                self.rows.append(self._row)
                self._row = None
            elif tag in ('td', 'th') and self._cell is not None:
                if self._row is not None:
                    self._row.append(''.join(self._cell).strip())
                self._cell = None

        def handle_data(self, data):
            # Only text inside an open cell matters; ignore inter-tag whitespace.
            if self._cell is not None:
                self._cell.append(data)

    collector = _CellCollector()
    collector.feed(html)
    rows = [r for r in collector.rows if r]
    if not rows:
        # No recognizable table content — fall back to the stripped HTML.
        return html

    width = max(len(r) for r in rows)

    def _fmt(cells: List[str]) -> str:
        # Pad short rows and escape pipes so cell text cannot break the table.
        padded = list(cells) + [''] * (width - len(cells))
        return '| ' + ' | '.join(c.replace('|', '\\|') for c in padded) + ' |'

    lines = [_fmt(rows[0]), '| ' + ' | '.join(['---'] * width) + ' |']
    lines.extend(_fmt(r) for r in rows[1:])
    return '\n'.join(lines)
+ 
def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str
                     ) -> Tuple[str, List[str]]:
    """Normalize table HTML inside *json_res*, save it as JSON, and collect
    Markdown renderings of every table.

    Each entry of ``json_res["table_res_list"]`` with a non-empty
    ``pred_html`` is normalized via ``normalize_markdown_table``; the
    normalized HTML is written back into *json_res* in place so the saved
    JSON reflects it. The Markdown conversion intentionally uses the
    original (pre-normalization) HTML, matching the prior behavior —
    ``save_markdown_tables`` normalizes again downstream.

    Args:
        json_res: Pipeline result dict (may be empty or None).
        output_dir: Directory for the JSON file (created if missing).
        base_name: Stem of the output file name ``{base_name}.json``.

    Returns:
        Tuple of (saved JSON path as a POSIX string, list of Markdown
        tables). Returns ``("", [])`` when *json_res* is falsy.
    """
    if not json_res:
        return "", []

    md_tables: List[str] = []
    for table in json_res.get("table_res_list") or []:
        html = table.get("pred_html", "")
        if not html:
            continue
        normalized_html = normalize_markdown_table(html)
        if html != normalized_html:
            # Mutate the entry in place so the dump below persists the change.
            table["pred_html"] = normalized_html
        md_tables.append(html_table_to_markdown(html))

    out_dir = Path(output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    json_fp = out_dir / f"{base_name}.json"
    with open(json_fp, "w", encoding="utf-8") as f:
        json.dump(json_res, f, ensure_ascii=False, indent=2)

    return json_fp.as_posix(), md_tables
+
def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
                         normalize_numbers: bool = True) -> str:
    """Write all Markdown tables into a single ``{base_name}.md`` file.

    Tables are written in order, each followed by a blank line. When
    *normalize_numbers* is true, each table is passed through
    ``normalize_markdown_table`` before writing.

    (Note: despite an earlier comment suggesting per-table files, this
    function writes exactly one combined file and returns its path.)

    Args:
        md_tables: Markdown table strings to write (may be empty).
        output_dir: Directory for the file (created if missing).
        base_name: Stem of the output file name ``{base_name}.md``.
        normalize_numbers: Apply number normalization to each table.

    Returns:
        POSIX path of the written Markdown file.
    """
    out_dir = Path(output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    markdown_path = out_dir / f"{base_name}.md"
    with open(markdown_path, "w", encoding="utf-8") as f:
        for md in md_tables:
            content = normalize_markdown_table(md) if normalize_numbers else md
            f.write(content + "\n\n")

    return markdown_path.as_posix()
+
def process_images_with_table_pipeline(
    image_paths: List[str],
    pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
    device: str = "gpu:0",
    output_dir: str = "./output",
    normalize_numbers: bool = True
) -> List[Dict[str, Any]]:
    """Run the table_recognition_v2 pipeline and export JSON/Markdown results.

    For every image: runs ``pipeline.predict``, saves the pipeline's own
    outputs via ``res.save_all``, writes a (possibly normalized) JSON copy
    through ``save_json_tables`` and a combined Markdown file through
    ``save_markdown_tables``, and appends a per-image status record.

    Args:
        image_paths: Image (or pre-rendered PDF page) paths to process.
        pipeline_cfg: Pipeline name or YAML config path for PaddleX.
        device: Device string such as "gpu:0" or "cpu".
        output_dir: Directory receiving all outputs (created if missing).
        normalize_numbers: Forwarded to ``save_markdown_tables``.

    Returns:
        One status dict per image: successful entries carry ``json_path``,
        ``markdown_path`` and ``tables_detected``; failed entries carry
        ``error``. Returns [] when pipeline initialization fails.
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
    try:
        # Suppress UserWarnings from libraries loaded during pipeline init.
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
        pipeline = create_pipeline(pipeline_cfg, device=device)
        print(f"Pipeline initialized successfully on {device}")
    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    results_all: List[Dict[str, Any]] = []
    total = len(image_paths)
    print(f"Processing {total} images with table_recognition_v2")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")

    with tqdm(total=total, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        for img_path in image_paths:
            img_path = str(img_path)
            start = time.time()
            try:
                # NOTE(review): assumes this PaddleX pipeline accepts these
                # toggle kwargs — confirm against the installed version.
                outputs = pipeline.predict(
                    img_path,
                    use_doc_preprocessor=True,
                    use_layout_detection=True,
                    use_ocr_model=True
                )
                cost = time.time() - start

                # Normally the pipeline yields exactly one result per image.
                for idx, res in enumerate(outputs):
                    if idx > 0:
                        # NOTE(review): this raise is caught by the outer
                        # except, so the image gets a failure record in
                        # addition to the success record already appended
                        # for idx 0 — confirm this double entry is intended.
                        raise ValueError("Multiple results found for a single image")

                    input_path = Path(res["input_path"])
                    base_name = input_path.stem

                    res.save_all(save_path=output_path.as_posix())  # save every pipeline artifact (JSON, visualizations, ...)
                    # Extract the structured payload; presumably some result
                    # versions nest it under "res" — hence the fallback.
                    json_res = res.json.get("res", res.json)

                    saved_json, md_tables = save_json_tables(json_res, str(output_path), base_name)
    
                    saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                     normalize_numbers=normalize_numbers)

                    results_all.append({
                        "image_path": str(input_path),
                        "success": True,
                        "time_sec": cost,
                        "device": device,
                        "json_path": saved_json,
                        "markdown_path": saved_md,
                        "tables_detected": len(md_tables),
                        # PDF pages are assumed pre-rendered to "*_page_*" images.
                        "is_pdf_page": "_page_" in input_path.name
                    })

                pbar.update(1)
                ok = sum(1 for r in results_all if r.get("success"))
                pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)

            except Exception as e:
                # Record the failure and keep going with the remaining images.
                traceback.print_exc()
                results_all.append({
                    "image_path": str(img_path),
                    "success": False,
                    "time_sec": 0,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)
                pbar.set_postfix_str("error")

    return results_all
+
def main():
    """CLI entry point.

    Parses arguments, collects input files, runs the table recognition
    pipeline over them, prints statistics, writes a summary JSON into the
    output directory and optionally a processed-files CSV.

    Returns:
        Process exit code: 0 on completion, 1 when no input files are found.
    """
    parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--input_file", type=str, help="单个文件(图片或PDF)")
    g.add_argument("--input_dir", type=str, help="目录")
    g.add_argument("--input_file_list", type=str, help="文件列表(每行一个路径)")
    g.add_argument("--input_csv", type=str, help="CSV,含 image_path 与 status 列")

    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--pipeline", type=str, default="./my_config/table_recognition_v2.yaml",
                        help="管线名称或配置文件路径(默认使用本仓库的 table_recognition_v2.yaml)")
    parser.add_argument("--device", type=str, default="gpu:0", help="gpu:0 或 cpu")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="PDF 转图像 DPI")
    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
    parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
    parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")

    args = parser.parse_args()
    normalize_numbers = not args.no_normalize

    # Reuse the input-collection logic from ppstructurev3_utils.
    input_files = get_input_files(args)
    if not input_files:
        print("❌ No input files found or processed")
        return 1
    if args.test_mode:
        input_files = input_files[:20]
        print(f"Test mode: processing only {len(input_files)} images")

    print(f"Using device: {args.device}")
    start = time.time()
    results = process_images_with_table_pipeline(
        input_files,
        args.pipeline,
        args.device,
        args.output_dir,
        normalize_numbers=normalize_numbers
    )
    total_time = time.time() - start

    success = sum(1 for r in results if r.get("success"))
    failed = len(results) - success
    pdf_pages = sum(1 for r in results if r.get("is_pdf_page", False))
    total_tables = sum(r.get("tables_detected", 0) for r in results if r.get("success"))

    print("\n" + "="*60)
    print("✅ Processing completed!")
    print("📊 Statistics:")
    print(f"  Total files processed: {len(input_files)}")
    print(f"  PDF pages processed: {pdf_pages}")
    print(f"  Regular images processed: {len(input_files) - pdf_pages}")
    print(f"  Successful: {success}")
    print(f"  Failed: {failed}")
    print("⏱️ Performance:")
    print(f"  Total time: {total_time:.2f} seconds")
    if total_time > 0:
        print(f"  Throughput: {len(input_files) / total_time:.2f} files/second")
        print(f"  Avg time per file: {total_time / len(input_files):.2f} seconds")
    print(f"  Tables detected (total): {total_tables}")

    # Persist a machine-readable summary next to the per-image outputs.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    summary = {
        "stats": {
            "total_files": len(input_files),
            "pdf_pages": pdf_pages,
            "regular_images": len(input_files) - pdf_pages,
            "success_count": success,
            "error_count": failed,
            "total_time_sec": total_time,
            "throughput_fps": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file_sec": total_time / len(input_files) if len(input_files) > 0 else 0,
            "pipeline": args.pipeline,
            "device": args.device,
            "normalize_numbers": normalize_numbers,
            "total_tables": total_tables,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        },
        "results": results
    }
    out_json = out_dir / f"{Path(args.output_dir).name}_table_recognition_v2.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"💾 Results saved to: {out_json}")

    # Optional processed-status CSV (best effort; failures are non-fatal).
    # NOTE(review): paths containing commas would break this hand-rolled
    # CSV — consider the csv module if such paths can occur.
    try:
        if args.collect_results:
            from utils import collect_pid_files
            processed_files = collect_pid_files(out_json.as_posix())
            csv_path = Path(args.collect_results).resolve()
            with open(csv_path, "w", encoding="utf-8") as f:
                f.write("image_path,status\n")
                for file_path, status in processed_files:
                    f.write(f"{file_path},{status}\n")
            print(f"💾 Processed files saved to: {csv_path}")
    except Exception as e:
        print(f"⚠️ Failed to save processed files CSV: {e}")

    return 0
+
if __name__ == "__main__":
    print("🚀 启动 table_recognition_v2 单管线处理程序...")
    if len(sys.argv) == 1:
        # No CLI arguments supplied: inject demo defaults (adjust as needed).
        demo = {
            "--input_file": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
            "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
            "--pipeline": "./my_config/table_recognition_v2.yaml",
            "--device": "cpu",
        }
        # Flatten {flag: value} into an argv tail: [flag1, value1, flag2, ...].
        argv_tail = []
        for flag, value in demo.items():
            argv_tail.append(flag)
            argv_tail.append(value)
        sys.argv = sys.argv[:1] + argv_tail

    sys.exit(main())