basic_predictor.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import abstractmethod
  15. from ....utils.subclass_register import AutoRegisterABCMetaClass
  16. from ....utils.func_register import FuncRegister
  17. from ....utils import logging
  18. from ...components.base import BaseComponent, ComponentsEngine
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ...utils.process_hook import generatorable_method
  21. from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
  22. from .base_predictor import BasePredictor
  23. class BasicPredictor(
  24. BasePredictor,
  25. DeviceSetMixin,
  26. PPOptionSetMixin,
  27. BatchSizeSetMixin,
  28. metaclass=AutoRegisterABCMetaClass,
  29. ):
  30. __is_base = True
  31. def __init__(self, model_dir, config=None, device=None, pp_option=None):
  32. super().__init__(model_dir=model_dir, config=config)
  33. self._pred_set_func_map = {}
  34. self._pred_set_register = FuncRegister(self._pred_set_func_map)
  35. self._pred_set_register("device")(self.set_device)
  36. self._pred_set_register("pp_option")(self.set_pp_option)
  37. self._pred_set_register("batch_size")(self.set_batch_size)
  38. self.pp_option = pp_option if pp_option else PaddlePredictorOption()
  39. self.pp_option.set_device(device)
  40. self.components = {}
  41. self._build_components()
  42. self.engine = ComponentsEngine(self.components)
  43. logging.debug(
  44. f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
  45. )
  46. def apply(self, x):
  47. """predict"""
  48. yield from self._generate_res(self.engine(x))
  49. @generatorable_method
  50. def _generate_res(self, batch_data):
  51. return [{"result": self._pack_res(data)} for data in batch_data]
  52. def _add_component(self, cmps):
  53. if not isinstance(cmps, list):
  54. cmps = [cmps]
  55. for cmp in cmps:
  56. if not isinstance(cmp, (list, tuple)):
  57. key = cmp.name
  58. else:
  59. assert len(cmp) == 2
  60. key = cmp[0]
  61. cmp = cmp[1]
  62. assert isinstance(key, str)
  63. assert isinstance(cmp, BaseComponent)
  64. assert (
  65. key not in self.components
  66. ), f"The key ({key}) has been used: {self.components}!"
  67. self.components[key] = cmp
  68. def set_predictor(self, **kwargs):
  69. for k in kwargs:
  70. assert (
  71. k in self._pred_set_func_map
  72. ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
  73. self._pred_set_func_map[k](kwargs[k])
  74. @abstractmethod
  75. def _build_components(self):
  76. raise NotImplementedError
  77. @abstractmethod
  78. def _pack_res(self, data):
  79. raise NotImplementedError