浏览代码

fix pipline model_name input (#2864)

zhangyubo0722 10 月之前
父节点
当前提交
8f99cccf95
共有 1 个文件被更改,包括 3 次插入5 次删除
  1. 3 5
      paddlex/inference/pipelines_new/base.py

+ 3 - 5
paddlex/inference/pipelines_new/base.py

@@ -81,14 +81,12 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         if "model_config_error" in config:
         if "model_config_error" in config:
             raise ValueError(config["model_config_error"])
             raise ValueError(config["model_config_error"])
 
 
-        model_dir = config["model_dir"]
-        if model_dir == None:
-            model_dir = config["model_name"]
-
+        model_dir = config.get("model_dir", None)
         from .. import create_predictor
         from .. import create_predictor
 
 
         model = create_predictor(
         model = create_predictor(
-            model=model_dir,
+            model_name=config["model_name"],
+            model_dir=model_dir,
             device=self.device,
             device=self.device,
             pp_option=self.pp_option,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
             use_hpip=self.use_hpip,