|
|
@@ -13,7 +13,9 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
+from typing import Any, Dict, Optional
|
|
|
|
|
|
+from ..predictors import create_predictor
|
|
|
from ...utils.subclass_register import AutoRegisterABCMetaClass
|
|
|
|
|
|
|
|
|
@@ -23,6 +25,8 @@ def create_pipeline(
|
|
|
model_dir_list: list,
|
|
|
output: str,
|
|
|
device: str,
|
|
|
+ use_hpip: bool,
|
|
|
+ hpi_params: Optional[Dict[str, Any]] = None,
|
|
|
) -> "BasePipeline":
|
|
|
"""build model evaluater
|
|
|
|
|
|
@@ -32,7 +36,12 @@ def create_pipeline(
|
|
|
Returns:
|
|
|
BasePipeline: the pipeline, which is subclass of BasePipeline.
|
|
|
"""
|
|
|
- pipeline = BasePipeline.get(pipeline_name)(output=output, device=device)
|
|
|
+ predictor_kwargs = {"use_hpip": use_hpip}
|
|
|
+ if hpi_params is not None:
|
|
|
+ predictor_kwargs["hpi_params"] = hpi_params
|
|
|
+ pipeline = BasePipeline.get(pipeline_name)(
|
|
|
+ output=output, device=device, predictor_kwargs=predictor_kwargs
|
|
|
+ )
|
|
|
pipeline.update_model(model_list, model_dir_list)
|
|
|
pipeline.load_model()
|
|
|
return pipeline
|
|
|
@@ -43,6 +52,15 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
|
|
|
__is_base = True
|
|
|
|
|
|
+ def __init__(self, predictor_kwargs: Optional[Dict[str, Any]]) -> None:
|
|
|
+ super().__init__()
|
|
|
+ if predictor_kwargs is None:
|
|
|
+ predictor_kwargs = {}
|
|
|
+ self._predictor_kwargs = predictor_kwargs
|
|
|
+
|
|
|
# alias the __call__() to predict()
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
yield from self.predict(*args, **kwargs)
|
|
|
+
|
|
|
+ def _create_predictor(self, *args, **kwargs):
|
|
|
+ return create_predictor(*args, **kwargs, **self._predictor_kwargs)
|