Przeglądaj źródła

fix pipline model_name input (#2864)

zhangyubo0722 10 miesięcy temu
rodzic
commit
8f99cccf95
1 zmienionych plików z 3 dodań i 5 usunięć
  1. 3 5
      paddlex/inference/pipelines_new/base.py

+ 3 - 5
paddlex/inference/pipelines_new/base.py

@@ -81,14 +81,12 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         if "model_config_error" in config:
             raise ValueError(config["model_config_error"])
 
-        model_dir = config["model_dir"]
-        if model_dir == None:
-            model_dir = config["model_name"]
-
+        model_dir = config.get("model_dir", None)
         from .. import create_predictor
 
         model = create_predictor(
-            model=model_dir,
+            model_name=config["model_name"],
+            model_dir=model_dir,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,