|
|
@@ -67,12 +67,13 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
"""
|
|
|
raise NotImplementedError("The method `predict` has not been implemented yet.")
|
|
|
|
|
|
- def create_model(self, config: Dict) -> BasePredictor:
|
|
|
+ def create_model(self, config: Dict, **kwargs) -> BasePredictor:
|
|
|
"""
|
|
|
Create a model instance based on the given configuration.
|
|
|
|
|
|
Args:
|
|
|
config (Dict): A dictionary containing configuration settings.
|
|
|
+ **kwargs: The model arguments that needed to be pass.
|
|
|
|
|
|
Returns:
|
|
|
BasePredictor: An instance of the model.
|
|
|
@@ -82,14 +83,15 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
if model_dir == None:
|
|
|
model_dir = config["model_name"]
|
|
|
|
|
|
- from ...model import create_model
|
|
|
+ from .. import create_predictor
|
|
|
|
|
|
- model = create_model(
|
|
|
+ model = create_predictor(
|
|
|
model=model_dir,
|
|
|
device=self.device,
|
|
|
pp_option=self.pp_option,
|
|
|
use_hpip=self.use_hpip,
|
|
|
hpi_params=self.hpi_params,
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
|
|
|
# [TODO] Support initializing with additional parameters
|