|
|
@@ -17,14 +17,10 @@ import codecs
|
|
|
from pathlib import Path
|
|
|
from abc import abstractmethod
|
|
|
|
|
|
-from ....utils.subclass_register import AutoRegisterABCMetaClass
|
|
|
-from ....utils.func_register import FuncRegister
|
|
|
-from ....utils import logging
|
|
|
-from ...utils.device import constr_device
|
|
|
-from ...components.base import BaseComponent, ComponentsEngine
|
|
|
+from ...components.base import BaseComponent
|
|
|
from ...utils.pp_option import PaddlePredictorOption
|
|
|
from ...utils.process_hook import generatorable_method
|
|
|
-from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin
|
|
|
+from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
|
|
|
|
|
|
|
|
|
class BasePredictor(BaseComponent):
|
|
|
@@ -48,7 +44,7 @@ class BasePredictor(BaseComponent):
|
|
|
self.predict = self.__call__
|
|
|
|
|
|
def __call__(self, input, **kwargs):
|
|
|
- self.set_predict(**kwargs)
|
|
|
+ self.set_predictor(**kwargs)
|
|
|
for res in super().__call__(input):
|
|
|
yield res["result"]
|
|
|
|
|
|
@@ -65,7 +61,7 @@ class BasePredictor(BaseComponent):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@abstractmethod
|
|
|
- def set_predict(self):
|
|
|
+ def set_predictor(self):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@classmethod
|
|
|
@@ -78,67 +74,3 @@ class BasePredictor(BaseComponent):
|
|
|
with codecs.open(config_path, "r", "utf-8") as file:
|
|
|
dic = yaml.load(file, Loader=yaml.FullLoader)
|
|
|
return dic
|
|
|
-
|
|
|
-
|
|
|
-class BasicPredictor(
|
|
|
- BasePredictor, DeviceSetMixin, PPOptionSetMixin, metaclass=AutoRegisterABCMetaClass
|
|
|
-):
|
|
|
-
|
|
|
- __is_base = True
|
|
|
-
|
|
|
- def __init__(self, model_dir, config=None, device=None, pp_option=None):
|
|
|
- super().__init__(model_dir=model_dir, config=config)
|
|
|
- self._pred_set_func_map = {}
|
|
|
- self._pred_set_register = FuncRegister(self._pred_set_func_map)
|
|
|
- self._pred_set_register("device")(self.set_device)
|
|
|
- self._pred_set_register("pp_option")(self.set_pp_option)
|
|
|
-
|
|
|
- self.pp_option = pp_option if pp_option else PaddlePredictorOption()
|
|
|
- self.pp_option.set_device(device)
|
|
|
- self.components = {}
|
|
|
- self._build_components()
|
|
|
- self.engine = ComponentsEngine(self.components)
|
|
|
- logging.debug(
|
|
|
- f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
|
|
|
- )
|
|
|
-
|
|
|
- def apply(self, x):
|
|
|
- """predict"""
|
|
|
- yield from self._generate_res(self.engine(x))
|
|
|
-
|
|
|
- @generatorable_method
|
|
|
- def _generate_res(self, batch_data):
|
|
|
- return [{"result": self._pack_res(data)} for data in batch_data]
|
|
|
-
|
|
|
- def _add_component(self, cmps):
|
|
|
- if not isinstance(cmps, list):
|
|
|
- cmps = [cmps]
|
|
|
-
|
|
|
- for cmp in cmps:
|
|
|
- if not isinstance(cmp, (list, tuple)):
|
|
|
- key = cmp.__class__.__name__
|
|
|
- else:
|
|
|
- assert len(cmp) == 2
|
|
|
- key = cmp[0]
|
|
|
- cmp = cmp[1]
|
|
|
- assert isinstance(key, str)
|
|
|
- assert isinstance(cmp, BaseComponent)
|
|
|
- assert (
|
|
|
- key not in self.components
|
|
|
- ), f"The key ({key}) has been used: {self.components}!"
|
|
|
- self.components[key] = cmp
|
|
|
-
|
|
|
- def set_predict(self, **kwargs):
|
|
|
- for k in kwargs:
|
|
|
- assert (
|
|
|
- k in self._pred_set_func_map
|
|
|
- ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
|
|
|
- self._pred_set_func_map[k](kwargs[k])
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def _build_components(self):
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def _pack_res(self, data):
|
|
|
- raise NotImplementedError
|