basic_predictor.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. from ....utils.subclass_register import AutoRegisterABCMetaClass
  16. from ....utils.func_register import FuncRegister
  17. from ....utils import logging
  18. from ...components.base import BaseComponent, ComponentsEngine
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ...utils.process_hook import generatorable_method
  21. from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
  22. from .base_predictor import BasePredictor
  23. class BasicPredictor(
  24. BasePredictor,
  25. DeviceSetMixin,
  26. PPOptionSetMixin,
  27. BatchSizeSetMixin,
  28. metaclass=AutoRegisterABCMetaClass,
  29. ):
  30. __is_base = True
  31. def __init__(self, model_dir, config=None, device=None, pp_option=None):
  32. super().__init__(model_dir=model_dir, config=config)
  33. self._pred_set_func_map = {}
  34. self._pred_set_register = FuncRegister(self._pred_set_func_map)
  35. self._pred_set_register("device")(self.set_device)
  36. self._pred_set_register("pp_option")(self.set_pp_option)
  37. self._pred_set_register("batch_size")(self.set_batch_size)
  38. self.pp_option = (
  39. pp_option
  40. if pp_option
  41. else PaddlePredictorOption(model_name=self.model_name)
  42. )
  43. self.pp_option.set_device(device)
  44. self.components = {}
  45. self._build_components()
  46. self.engine = ComponentsEngine(self.components)
  47. logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
  48. def apply(self, x):
  49. """predict"""
  50. yield from self._generate_res(self.engine(x))
  51. @generatorable_method
  52. def _generate_res(self, batch_data):
  53. return [{"result": self._pack_res(data)} for data in batch_data]
  54. def _add_component(self, cmps):
  55. if not isinstance(cmps, list):
  56. cmps = [cmps]
  57. for cmp in cmps:
  58. if not isinstance(cmp, (list, tuple)):
  59. key = cmp.name
  60. else:
  61. assert len(cmp) == 2
  62. key = cmp[0]
  63. cmp = cmp[1]
  64. assert isinstance(key, str)
  65. assert isinstance(cmp, BaseComponent)
  66. assert (
  67. key not in self.components
  68. ), f"The key ({key}) has been used: {self.components}!"
  69. self.components[key] = cmp
  70. def set_predictor(self, **kwargs):
  71. for k in kwargs:
  72. assert (
  73. k in self._pred_set_func_map
  74. ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
  75. self._pred_set_func_map[k](kwargs[k])
  76. @abstractmethod
  77. def _build_components(self):
  78. raise NotImplementedError
  79. @abstractmethod
  80. def _pack_res(self, data):
  81. raise NotImplementedError