|
|
@@ -25,42 +25,36 @@ from .modules import (
|
|
|
|
|
|
|
|
|
# TODO(gaotingquan): support _ModelBasedConfig
|
|
|
-def create_model(model=None, **kwargs):
|
|
|
- return _ModelBasedInference(model, **kwargs)
|
|
|
+def create_model(model=None, *args, **kwargs):
|
|
|
+ return _ModelBasedInference(model, *args, **kwargs)
|
|
|
|
|
|
|
|
|
class _BaseModel:
|
|
|
- @abstractmethod
|
|
|
def check_dataset(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("check_dataset is not supported!")
|
|
|
|
|
|
- @abstractmethod
|
|
|
def train(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("train is not supported!")
|
|
|
|
|
|
- @abstractmethod
|
|
|
def evaluate(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("evaluate is not supported!")
|
|
|
|
|
|
- @abstractmethod
|
|
|
def export(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("export is not supported!")
|
|
|
|
|
|
- @abstractmethod
|
|
|
def predict(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("predict is not supported!")
|
|
|
|
|
|
- @abstractmethod
|
|
|
def set_predict(self, *args, **kwargs):
|
|
|
- raise NotImplementedError
|
|
|
+ raise Exception("set_predict is not supported!")
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
yield from self.predict(*args, **kwargs)
|
|
|
|
|
|
|
|
|
class _ModelBasedInference(_BaseModel):
|
|
|
- def __init__(self, model, device=None, **kwargs):
|
|
|
- self._predictor = create_predictor(model, device=device, **kwargs)
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ self._predictor = create_predictor(*args, **kwargs)
|
|
|
|
|
|
def predict(self, *args, **kwargs):
|
|
|
yield from self._predictor(*args, **kwargs)
|
|
|
@@ -108,5 +102,5 @@ class _ModelBasedConfig(_BaseModel):
|
|
|
return exportor.export()
|
|
|
|
|
|
def predict(self):
|
|
|
- _predict_kwargs, _predictor = self._build_predictor()
|
|
|
- yield from _predictor(**_predict_kwargs)
|
|
|
+ predict_kwargs, predictor = self._build_predictor()
|
|
|
+ yield from predictor(**predict_kwargs)
|