|
|
@@ -81,14 +81,12 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
if "model_config_error" in config:
|
|
|
raise ValueError(config["model_config_error"])
|
|
|
|
|
|
- model_dir = config["model_dir"]
|
|
|
- if model_dir == None:
|
|
|
- model_dir = config["model_name"]
|
|
|
-
|
|
|
+ model_dir = config.get("model_dir", None)
|
|
|
from .. import create_predictor
|
|
|
|
|
|
model = create_predictor(
|
|
|
- model=model_dir,
|
|
|
+ model_name=config["model_name"],
|
|
|
+ model_dir=model_dir,
|
|
|
device=self.device,
|
|
|
pp_option=self.pp_option,
|
|
|
use_hpip=self.use_hpip,
|