浏览代码

refactor(zhch): 重构测试脚本并添加环境初始化

- 引入 paddle 和 os 模块
- 添加 GPU 检测和内存监控功能
- 通过 dotenv 加载环境变量
- 优化代码结构,将主程序放入 if __name__ == "__main__": 块中
- 更新 pipeline 初始化方式,直接使用设备参数
- 移除冗余的注释代码
zhch158_admin 3 月之前
父节点
当前提交
9f54e44cf5
共有 1 个文件被更改,包括 43 次插入26 次删除
  1. 43 26
      zhch/test_ppstructure_v3.py

+ 43 - 26
zhch/test_ppstructure_v3.py

@@ -1,30 +1,47 @@
 from paddlex import create_pipeline
+import paddle
 import time
 from pathlib import Path
+import os
+from typing import List
 
-input_path = "./sample_data/300674-母公司现金流量表-扫描.png"
-
-pipeline_path = "./PP-StructureV3-zhch.yaml"
-pipeline_name = Path(pipeline_path).stem
-output_path = Path(f"./sample_data/single_pipeline_output/{pipeline_name}/")
-
-pipeline = create_pipeline(pipeline=pipeline_path)
-
-# For Image
-output = pipeline.predict(
-    input=input_path,
-    device="gpu",  # 或者 "gpu" 如果你有 GPU 支持
-    use_doc_orientation_classify=True, # 开启文档方向分类
-    use_doc_unwarping=False, # 开启文档去畸变, 效果不佳
-    layout_detection_model_name=None, # 如果要禁用版面分析,可以这样设置,或者依赖其默认行为
-    use_seal_recognition=True,         # 跳过印章识别
-    use_chart_recognition=True,         # 跳过图表识别
-    use_table_recognition=True,        # 开启表格识别
-)
-
-# 可视化结果并保存 json 结果
-for res in output:
-    res.print() 
-    # res.save_to_json(save_path="sample_data/output") 
-    # res.save_to_markdown(save_path="sample_data/output") 
-    res.save_all(save_path=output_path.as_posix())  # 保存所有结果到指定路径
+from cuda_utils import detect_available_gpus, monitor_gpu_memory
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+
+if __name__ == "__main__":
+    print(f"🚀 启动演示程序...")
+    print(f"CUDA 版本: {paddle.device.cuda.get_device_name()}")
+    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
+    available_gpus = detect_available_gpus()
+    monitor_gpu_memory(available_gpus)
+
+    input_path = "./sample_data/300674-母公司现金流量表-扫描.png"
+
+    # pipeline_path = "./PP-StructureV3-zhch.yaml"
+    pipeline_path = "PP-StructureV3"
+    pipeline_name = Path(pipeline_path).stem
+    output_path = Path(f"./sample_data/single_pipeline_output/{pipeline_name}/")
+
+    pipeline = create_pipeline(pipeline=pipeline_path, device="gpu:0")
+
+    # For Image
+    output = pipeline.predict(
+        input=input_path,
+        # device="gpu",  # 或者 "gpu" 如果你有 GPU 支持
+        use_doc_orientation_classify=True, # 开启文档方向分类
+        use_doc_unwarping=False, # 开启文档去畸变, 效果不佳
+        layout_detection_model_name=None, # 如果要禁用版面分析,可以这样设置,或者依赖其默认行为
+        use_seal_recognition=True,         # 跳过印章识别
+        use_chart_recognition=True,         # 跳过图表识别
+        use_table_recognition=True,        # 开启表格识别
+    )
+
+    # 可视化结果并保存 json 结果
+    for res in output:
+        res.print() 
+        # res.save_to_json(save_path="sample_data/output") 
+        # res.save_to_markdown(save_path="sample_data/output") 
+        res.save_all(save_path=output_path.as_posix())  # 保存所有结果到指定路径