basic_predictor.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. from ....utils.subclass_register import AutoRegisterABCMetaClass
  16. from ....utils.func_register import FuncRegister
  17. from ....utils import logging
  18. from ...components.base import BaseComponent, ComponentsEngine
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ...utils.process_hook import generatorable_method
  21. from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
  22. from .base_predictor import BasePredictor
  23. class BasicPredictor(
  24. BasePredictor,
  25. DeviceSetMixin,
  26. PPOptionSetMixin,
  27. BatchSizeSetMixin,
  28. metaclass=AutoRegisterABCMetaClass,
  29. ):
  30. __is_base = True
  31. def __init__(
  32. self, model_dir, config=None, device=None, pp_option=None, **option_kwargs
  33. ):
  34. super().__init__(model_dir=model_dir, config=config)
  35. self._pred_set_func_map = {}
  36. self._pred_set_register = FuncRegister(self._pred_set_func_map)
  37. self._pred_set_register("device")(self.set_device)
  38. self._pred_set_register("pp_option")(self.set_pp_option)
  39. self._pred_set_register("batch_size")(self.set_batch_size)
  40. self.pp_option = (
  41. pp_option
  42. if pp_option
  43. else PaddlePredictorOption(model_name=self.model_name, **option_kwargs)
  44. )
  45. self.pp_option.set_device(device)
  46. self.components = {}
  47. self._build_components()
  48. self.engine = ComponentsEngine(self.components)
  49. logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
  50. def apply(self, x):
  51. """predict"""
  52. yield from self._generate_res(self.engine(x))
  53. @generatorable_method
  54. def _generate_res(self, batch_data):
  55. return [{"result": self._pack_res(data)} for data in batch_data]
  56. def _add_component(self, cmps):
  57. if not isinstance(cmps, list):
  58. cmps = [cmps]
  59. for cmp in cmps:
  60. if not isinstance(cmp, (list, tuple)):
  61. key = cmp.name
  62. else:
  63. assert len(cmp) == 2
  64. key = cmp[0]
  65. cmp = cmp[1]
  66. assert isinstance(key, str)
  67. assert isinstance(cmp, BaseComponent)
  68. assert (
  69. key not in self.components
  70. ), f"The key ({key}) has been used: {self.components}!"
  71. self.components[key] = cmp
  72. def set_predictor(self, **kwargs):
  73. for k in kwargs:
  74. assert (
  75. k in self._pred_set_func_map
  76. ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
  77. self._pred_set_func_map[k](kwargs[k])
  78. @abstractmethod
  79. def _build_components(self):
  80. raise NotImplementedError
  81. @abstractmethod
  82. def _pack_res(self, data):
  83. raise NotImplementedError