
feat: Add single-model prediction supporting image/PDF processing and result saving

zhch158_admin · 1 month ago
Parent · Current commit · 4780efe026
1 changed file with 256 additions and 0 deletions
  zhch/model_single_process.py

+ 256 - 0
zhch/model_single_process.py

@@ -0,0 +1,256 @@
+import os
+import sys
+import time
+import json
+import argparse
+import traceback
+from pathlib import Path
+from typing import List, Dict, Any
+
+from tqdm import tqdm
+from dotenv import load_dotenv
+load_dotenv(override=True)
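+# Load .env first so any environment-based configuration is already set
+# before paddlex is imported.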
+
+from paddlex import create_model
+
+# Reuse the existing input collection and PDF-to-image logic
+from ppstructurev3_utils import get_input_files
+
+# Reference list of PaddleX model names
+MODEL_LIST = [
+    # OCR text detection models
+    {"model_name": "PP-OCRv5_mobile_det", "description": "Lightweight OCR text detection model, suitable for mobile/edge deployment"},
+    {"model_name": "PP-OCRv5_server_det", "description": "PP-OCRv5_rec 是新一代文本识别模型。该模型致力于以单一模型高效、精准地支持简体中文、繁体中文、英文、日文四种主要语言,以及手写、竖版、拼音、生僻字等复杂文本场景的识别。在保持识别效果的同时,兼顾推理速度和模型鲁棒性,为各种场景下的文档理解提供高效、精准的技术支撑。"},
+    
+    # OCR text recognition models
+    {"model_name": "PP-OCRv5_mobile_rec", "description": "Lightweight OCR text recognition model, suitable for mobile/edge deployment"},
+    {"model_name": "PP-OCRv5_server_rec", "description": "Server-side OCR text recognition model, high-accuracy recognition"},
+
+    # Layout region detection models
+    {"model_name": "PP-DocLayout_plus-L", "description": "Layout detection model covering 20 common categories: document title, paragraph title, text, page number, abstract, table of contents, references, footnote, header, footer, algorithm, formula, formula number, image, table, figure/table captions (figure caption, table caption, chart caption), seal, chart, sidebar text, and reference content"},
+    {"model_name": "PP-DocBlockLayout", "description": "Document layout block detection with a single 'layout region' category; detects the text region of each sub-article in multi-column newspapers and magazines"},
+    
+    # Table classification model
+    {"model_name": "PP-LCNet_x1_0_table_cls", "description": "Table classification model; distinguishes wired_table vs. wireless_table"},
+    
+    # Table structure recognition models
+    {"model_name": "SLANet_plus", "description": "SLANet_plus is an enhanced version of SLANet, the table structure recognition model developed by the Baidu PaddlePaddle vision team. Compared with SLANet, it greatly improves recognition of wireless and complex tables and is less sensitive to table localization accuracy: even if the table is localized with some offset, it can still be recognized fairly accurately."},
+    {"model_name": "SLANeXt_wired", "description": "The SLANeXt series is the new generation of table structure recognition models developed by the Baidu PaddlePaddle vision team. Compared with SLANet and SLANet_plus, SLANeXt focuses on table structure recognition and trains dedicated weights for wired and wireless tables, improving recognition of all table types, especially wired tables."},
+    {"model_name": "SLANeXt_wireless", "description": "The SLANeXt series is the new generation of table structure recognition models developed by the Baidu PaddlePaddle vision team. Compared with SLANet and SLANet_plus, SLANeXt focuses on table structure recognition and trains dedicated weights for wired and wireless tables, improving recognition of all table types, especially wireless tables."},
+
+    # Table cell detection models
+    {"model_name": "RT-DETR-L_wired_table_cell_det", "description": "Wired table cell detection model"},
+    {"model_name": "RT-DETR-L_wireless_table_cell_det", "description": "Wireless table cell detection model"},
+
+    # Formula recognition model
+    {"model_name": "PP-FormulaNet_plus-L", "description": "Converts mathematical formulas in images into editable text or a machine-readable format; its performance directly affects the accuracy and efficiency of the whole OCR system. Formula recognition typically outputs LaTeX or MathML code"},
+    
+    # Document image orientation classification model
+    {"model_name": "PP-LCNet_x1_0_doc_ori", "description": "Document image orientation classification model based on PP-LCNet_x1_0, with four classes: 0, 90, 180, and 270 degrees"},
+
+    # Text image rectification model
+    {"model_name": "UVDoc", "description": "Applies geometric transforms to correct document warping, skew, perspective distortion, etc., so that subsequent text recognition is more accurate"},
+
+    # Seal text detection models
+    {"model_name": "PP-OCRv4_mobile_seal_det", "description": "Mobile seal text detection model of PP-OCRv4, more efficient, suitable for edge deployment"},
+    {"model_name": "PP-OCRv4_server_seal_det", "description": "Server-side seal text detection model of PP-OCRv4, more accurate, suitable for deployment on well-equipped servers"},
+
+]
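+# Note: MODEL_LIST is reference documentation only; --model_name is not
+# validated against it, so any name accepted by create_model() can be used.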
+
+# Models that require dict input (Doc VLM / chart-to-table)
+DICT_INPUT_MODELS = {
+    "PP-Chart2Table",
+    "PP-DocBee-2B",
+    "PP-DocBee-7B",
+    "PP-DocBee2-3B",
+}
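+# For these models, predict() receives {"image": ..., "query": ...}
+# instead of a plain image path (see predict_on_images below).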
+
+def init_model(model_name: str, device: str = "gpu:0"):
+    """
+    Initialize a single model. Fall back to the default constructor if the device argument is not supported.
+    """
+    try:
+        model = create_model(model_name=model_name, device=device)
+    except TypeError:
+        model = create_model(model_name=model_name)
+    return model
+
+def predict_on_images(
+    model_name: str,
+    image_paths: List[str],
+    output_dir: str,
+    device: str = "gpu:0",
+    batch_size: int = 1,
+    layout_nms: bool = True,
+    query: str = "请将图表转换为表格格式"
+) -> List[Dict[str, Any]]:
+    """
+    Run any single model on a list of images, save visualizations and raw results, and return a summary.
+    """
+    output_base = Path(output_dir).resolve()
+    output_base.mkdir(parents=True, exist_ok=True)
+
+    model = init_model(model_name, device=device)
+
+    # Some detection/layout models support layout_nms
+    predict_kwargs = {}
+    if hasattr(model, "_predictor") and hasattr(model._predictor, "layout_nms"):
+        predict_kwargs["layout_nms"] = layout_nms
+
+    results_summary: List[Dict[str, Any]] = []
+
+    with tqdm(total=len(image_paths), desc=f"{model_name} predicting", unit="img",
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        for img_path in image_paths:
+            img_path = str(img_path)
+            # img_name = Path(img_path).stem
+            # img_out_dir = output_base / img_name
+            # img_out_dir.mkdir(parents=True, exist_ok=True)
+
+            start = time.time()
+            try:
+                # Models that require dict input
+                if model_name in DICT_INPUT_MODELS:
+                    input_data = {"image": img_path, "query": query}
+                    outputs = model.predict(input_data, batch_size=1, **predict_kwargs)
+                else:
+                    outputs = model.predict(img_path, batch_size=batch_size, **predict_kwargs)
+
+                elapsed = time.time() - start
+
+                # Save model outputs (visualization and structured results)
+                saved_files = []
+                for i, res in enumerate(outputs):
+                    # Optionally use a sub-directory per result:
+                    # sub_dir = img_out_dir / f"res_{i:02d}"
+                    # sub_dir.mkdir(parents=True, exist_ok=True)
+                    # # Visualization and all artifacts
+                    # res.save_all(save_path=sub_dir.as_posix())
+                    # saved_files.append(sub_dir.as_posix())
+                    res.save_all(save_path=output_base.as_posix())
+                    saved_files.append(output_base.as_posix())
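+                # Note: every image writes into the same output_base; the
+                # per-image subdirectory variant above is left commented out.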
+
+                results_summary.append({
+                    "image_path": img_path,
+                    "success": True,
+                    "model_name": model_name,
+                    "device": device,
+                    "batch_size": batch_size,
+                    "layout_nms": layout_nms,
+                    "time_sec": elapsed,
+                    "saved_paths": saved_files
+                })
+
+                pbar.update(1)
+                pbar.set_postfix(time=f"{elapsed:.2f}s", ok=len([r for r in results_summary if r['success']]))
+
+            except Exception as e:
+                elapsed = time.time() - start
+                traceback.print_exc()
+                results_summary.append({
+                    "image_path": img_path,
+                    "success": False,
+                    "model_name": model_name,
+                    "device": device,
+                    "batch_size": batch_size,
+                    "layout_nms": layout_nms,
+                    "time_sec": elapsed,
+                    "error": str(e)
+                })
+                pbar.update(1)
+                pbar.set_postfix_str("error")
+
+    return results_summary
+
+def save_summary(summary: List[Dict[str, Any]], output_dir: str, model_name: str):
+    out_dir = Path(output_dir).resolve()
+    out_dir.mkdir(parents=True, exist_ok=True)
+    stats = {
+        "model_name": model_name,
+        "total": len(summary),
+        "success": sum(1 for r in summary if r.get("success")),
+        "failed": sum(1 for r in summary if not r.get("success")),
+        "avg_time": (sum(r.get("time_sec", 0) for r in summary) / len(summary)) if summary else 0,
+        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+    }
+    final = {"stats": stats, "results": summary}
+    out_file = out_dir / f"{model_name}_results.json"
+    with open(out_file, "w", encoding="utf-8") as f:
+        json.dump(final, f, ensure_ascii=False, indent=2)
+    print(f"💾 Summary saved to: {out_file}")
+
+def main():
+    parser = argparse.ArgumentParser(description="Run any single PaddleX model on images/PDFs (similar to ppstructurev3_single_process.py)")
+    # Input sources (consistent with ppstructurev3_single_process)
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--input_file", type=str, help="A single file (image or PDF)")
+    group.add_argument("--input_dir", type=str, help="A directory (scanned for images or PDFs)")
+    group.add_argument("--input_file_list", type=str, help="A file list (one path per line)")
+    group.add_argument("--input_csv", type=str, help="CSV file with image_path and status columns")
+
+    parser.add_argument("--model_name", type=str, required=True, help="Model name to run, e.g. PP-OCRv5_server_det / PP-DocLayout_plus-L / SLANeXt_wireless")
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    parser.add_argument("--device", type=str, default="gpu:0", help="Device, e.g. gpu:0 or cpu")
+    parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF-to-image conversion")
+    parser.add_argument("--batch_size", type=int, default=1, help="Prediction batch size (supported by most single-image models)")
+    parser.add_argument("--no_layout_nms", action="store_true", help="Disable layout_nms (if supported by the model)")
+    parser.add_argument("--query", type=str, default="请将图表转换为表格格式", help="Only used by models that require dict input, e.g. PP-Chart2Table")
+    parser.add_argument("--test_mode", action="store_true", help="Process only the first 20 files")
+
+    args = parser.parse_args()
+
+    # Reuse the file-collection capability of ppstructurev3_utils (including PDF-to-image conversion)
+    class DummyArgs:
+        input_file = args.input_file
+        input_dir = args.input_dir
+        input_file_list = args.input_file_list
+        input_csv = args.input_csv
+        output_dir = args.output_dir
+        pdf_dpi = args.pdf_dpi
+        test_mode = args.test_mode
+
+    input_files = get_input_files(DummyArgs)
+    if not input_files:
+        print("❌ No input files found.")
+        return 1
+    if args.test_mode:
+        input_files = input_files[:20]
+        print(f"Test mode: {len(input_files)} files")
+
+    print(f"🚀 Model: {args.model_name} | Device: {args.device} | Files: {len(input_files)}")
+
+    summary = predict_on_images(
+        model_name=args.model_name,
+        image_paths=input_files,
+        output_dir=args.output_dir,
+        device=args.device,
+        batch_size=args.batch_size,
+        layout_nms=not args.no_layout_nms,
+        query=args.query
+    )
+    save_summary(summary, args.output_dir, args.model_name)
+    return 0
+
+if __name__ == "__main__":
+    # No-argument demo (for a quick try-out)
+    if len(sys.argv) == 1:
+        model_name = "RT-DETR-L_wired_table_cell_det"
+        # demo = {
+        #     "--model_name": model_name,
+        #     "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
+        #     "--output_dir": f"/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/{model_name}_Results",
+        #     "--device": "cpu",
+        # }
+        
+        model_name = "RT-DETR-L_wireless_table_cell_det"
+        demo = {
+            "--model_name": model_name,
+            "--input_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.img",
+            "--output_dir": f"/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/{model_name}_Results",
+            "--device": "cpu",
+        }
+
+        sys.argv = [sys.argv[0]] + [kv for pair in demo.items() for kv in pair]
+        print("ℹ️  No args provided. Running demo with:", demo)
+
+    sys.exit(main())
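
For reference, a minimal sketch of driving the new helpers directly from Python instead of the CLI; the import path and the sample input/output paths below are placeholders, not part of this commit:

    from model_single_process import predict_on_images, save_summary

    # Run one detection model on a couple of local images (hypothetical paths)
    summary = predict_on_images(
        model_name="PP-OCRv5_mobile_det",
        image_paths=["./samples/page_01.png", "./samples/page_02.png"],
        output_dir="./PP-OCRv5_mobile_det_Results",
        device="cpu",
    )
    save_summary(summary, "./PP-OCRv5_mobile_det_Results", "PP-OCRv5_mobile_det")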