浏览代码

update get_pipeline_config method

cuicheng01 10 月之前
父节点
当前提交
ce3d03a610
共有 1 个文件被更改,包括 7 次插入3 次删除
  1. 7 3
      paddlex/inference/utils/get_pipeline_path.py

+ 7 - 3
paddlex/inference/utils/get_pipeline_path.py

@@ -18,10 +18,14 @@ from pathlib import Path
 
 def get_pipeline_path(pipeline_name):
     # XXX: using dict class to handle all pipeline configs
+    from ...utils.flags import USE_NEW_INFERENCE
+
+    if USE_NEW_INFERENCE:
+        config_subdir = "configs/pipelines"
+    else:
+        config_subdir = "pipelines"
     pipeline_path = (
-        Path(__file__).parent.parent.parent
-        / "configs/pipelines"
-        / f"{pipeline_name}.yaml"
+        Path(__file__).parent.parent.parent / config_subdir / f"{pipeline_name}.yaml"
     ).resolve()
     if not Path(pipeline_path).exists():
         return None