|
|
@@ -24,9 +24,10 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
|
|
|
__is_base = True
|
|
|
|
|
|
- def __init__(self, predictor_kwargs) -> None:
|
|
|
+ def __init__(self, device, predictor_kwargs={}) -> None:
|
|
|
super().__init__()
|
|
|
- self._predictor_kwargs = {} if predictor_kwargs is None else predictor_kwargs
|
|
|
+ self._predictor_kwargs = predictor_kwargs
|
|
|
+ self._device = device
|
|
|
|
|
|
@abstractmethod
|
|
|
def set_predictor():
|
|
|
@@ -41,9 +42,18 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
def _create(self, model=None, pipeline=None, *args, **kwargs):
|
|
|
if model:
|
|
|
return create_predictor(
|
|
|
- model=model, *args, **kwargs, **self._predictor_kwargs
|
|
|
+ *args,
|
|
|
+ model=model,
|
|
|
+ device=self._device,
|
|
|
+ **kwargs,
|
|
|
+ **self._predictor_kwargs
|
|
|
)
|
|
|
elif pipeline:
|
|
|
- return pipeline(*args, **kwargs, predictor_kwargs=self._predictor_kwargs)
|
|
|
+ return pipeline(
|
|
|
+ *args,
|
|
|
+ device=self._device,
|
|
|
+ predictor_kwargs=self._predictor_kwargs,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
else:
|
|
|
raise Exception()
|