# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy
from abc import ABC, abstractmethod

from .kernel_option import PaddleInferenceOption
from .utils.paddle_inference_predictor import _PaddleInferencePredictor
from .utils.mixin import FromDictMixin
from .utils.batch import batchable_method, Batcher
from .utils.node import Node
from .utils.official_models import official_models
from ....utils.device import get_device
from ....utils import logging
from ....utils.config import AttrDict


class BasePredictor(ABC, FromDictMixin, Node):
    """ Base Predictor """
    __is_base = True
    MODEL_FILE_TAG = 'inference'

    def __init__(self,
                 model_name,
                 model_dir,
                 kernel_option,
                 output,
                 pre_transforms=None,
                 post_transforms=None):
        super().__init__()
        self.model_name = model_name
        self.model_dir = model_dir
        self.kernel_option = kernel_option
        self.output = output
        self.other_src = self.load_other_src()
        logging.debug(
            f"-------------------- {self.__class__.__name__} --------------------\n"
            f"Model: {self.model_dir}\n"
            f"Env: {self.kernel_option}")
        self.pre_tfs, self.post_tfs = self.build_transforms(pre_transforms,
                                                            post_transforms)
        param_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdiparams")
        model_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdmodel")
        self._predictor = _PaddleInferencePredictor(
            param_path=param_path, model_path=model_path, option=kernel_option)
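
    # Expected layout of `model_dir`, given the paths constructed above
    # (with MODEL_FILE_TAG == 'inference'):
    #
    #     model_dir/
    #         inference.pdmodel      # serialized model structure
    #         inference.pdiparams    # model weights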

    def build_transforms(self, pre_transforms, post_transforms):
        """ Build pre-transforms and post-transforms, falling back to the
        model config when none are given explicitly. """
        pre_tfs = (pre_transforms if pre_transforms is not None else
                   self._get_pre_transforms_from_config())
        logging.debug(f"Preprocess Ops: {self._format_transforms(pre_tfs)}")
        post_tfs = (post_transforms if post_transforms is not None else
                    self._get_post_transforms_from_config())
        logging.debug(f"Postprocessing: {self._format_transforms(post_tfs)}")
        return pre_tfs, post_tfs

    def predict(self, input, batch_size=1):
        """ predict """
        if not isinstance(input, dict) and not (isinstance(input, list) and all(
                isinstance(ele, dict) for ele in input)):
            raise TypeError("`input` should be a dict or a list of dicts.")
        orig_input = input
        if isinstance(input, dict):
            input = [input]
        output = []
        for mini_batch in Batcher(input, batch_size=batch_size):
            # Preprocess each mini-batch, validate the sample keys, run
            # inference, validate the result keys, then postprocess.
            mini_batch = self._preprocess(
                mini_batch, pre_transforms=self.pre_tfs)
            for data in mini_batch:
                self.check_input_keys(data)
            mini_batch = self._run(batch_input=mini_batch)
            for data in mini_batch:
                self.check_output_keys(data)
            mini_batch = self._postprocess(
                mini_batch, post_transforms=self.post_tfs)
            output.extend(mini_batch)
        # Return a single dict when a single dict was passed in.
        if isinstance(orig_input, dict):
            return output[0]
        else:
            return output
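
    # Call-pattern sketch (the paths and batch size are illustrative; the
    # exact input keys required come from the first pre-transform, see
    # `get_input_keys()`):
    #
    #     result = predictor.predict({'input_path': 'demo.jpg'})
    #     results = predictor.predict(
    #         [{'input_path': p} for p in paths], batch_size=8)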

    @abstractmethod
    def _run(self, batch_input):
        """ run inference on a preprocessed mini-batch """
        raise NotImplementedError

    @abstractmethod
    def _get_pre_transforms_from_config(self):
        """ get preprocess transforms """
        raise NotImplementedError

    @abstractmethod
    def _get_post_transforms_from_config(self):
        """ get postprocess transforms """
        raise NotImplementedError

    @batchable_method
    def _preprocess(self, data, pre_transforms):
        """ preprocess """
        for tf in pre_transforms:
            data = tf(data)
        return data

    @batchable_method
    def _postprocess(self, data, post_transforms):
        """ postprocess """
        for tf in post_transforms:
            data = tf(data)
        return data

    def _format_transforms(self, transforms):
        """ format transforms """
        ops_str = ", ".join(str(tf) for tf in transforms)
        return f"[{ops_str}]"

    def load_other_src(self):
        """ Load any extra sources the predictor needs; returns None by
        default, and subclasses may override. """
        return None

    def get_input_keys(self):
        """ Get the keys of the input dict expected by the first
        pre-transform. """
        return self.pre_tfs[0].get_input_keys()
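

# --------------------------------------------------------------------------
# Minimal sketch of a concrete subclass, showing the three hooks that
# `BasePredictor` leaves abstract. `ClsPredictor`, the transforms
# (`ReadImage`, `Normalize`, `Topk`), the 'image'/'cls_result' keys, and the
# `self._predictor.predict(...)` call signature are illustrative assumptions,
# not part of this module:
#
#     class ClsPredictor(BasePredictor):
#         def _get_pre_transforms_from_config(self):
#             return [ReadImage(), Normalize()]
#
#         def _get_post_transforms_from_config(self):
#             return [Topk(k=5)]
#
#         def _run(self, batch_input):
#             images = [data['image'] for data in batch_input]
#             preds = self._predictor.predict(images)
#             for data, pred in zip(batch_input, preds):
#                 data['cls_result'] = pred
#             return batch_input
# --------------------------------------------------------------------------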


class PredictorBuilderByConfig(object):
    """ Build a model predictor from a pipeline config """

    def __init__(self, config):
        """
        Args:
            config (AttrDict): PaddleX pipeline config, loaded from the
                pipeline YAML file.
        """
        model_name = config.Global.model
        device = config.Global.device

        predict_config = deepcopy(config.Predict)
        model_dir = predict_config.pop('model_dir')
        # Kernel settings come from the config; default the device to the
        # global one when it is not set explicitly.
        kernel_setting = predict_config.pop('kernel_option', {})
        kernel_setting.setdefault('device', device)
        kernel_option = PaddleInferenceOption(**kernel_setting)

        self.input_path = predict_config.pop('input_path')
        self.predictor = BasePredictor.get(model_name)(
            model_name=model_name,
            model_dir=model_dir,
            kernel_option=kernel_option,
            output=config.Global.output,
            **predict_config)

    def predict(self):
        """ predict """
        self.predictor.predict({'input_path': self.input_path})


def build_predictor(*args, **kwargs):
    """ Build a predictor from a pipeline config (for development use) """
    return PredictorBuilderByConfig(*args, **kwargs)
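

# Driver-usage sketch, assuming a loader that turns the pipeline YAML into an
# AttrDict with `Global` and `Predict` sections (`load_pipeline_config` is a
# hypothetical name, not defined in this module):
#
#     config = load_pipeline_config('pipeline.yaml')
#     builder = build_predictor(config)
#     builder.predict()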


def create_model(model_name,
                 model_dir=None,
                 kernel_option=None,
                 output="./",
                 pre_transforms=None,
                 post_transforms=None,
                 *args,
                 **kwargs):
    """ Create a model for predicting with an inference model """
    if kernel_option is None:
        kernel_option = PaddleInferenceOption()
    # Fall back to the downloaded official weights when no local model_dir is
    # given and the model name is a known official model.
    if model_dir is None:
        if model_name in official_models:
            model_dir = official_models[model_name]
    return BasePredictor.get(model_name)(model_name=model_name,
                                         model_dir=model_dir,
                                         kernel_option=kernel_option,
                                         output=output,
                                         pre_transforms=pre_transforms,
                                         post_transforms=post_transforms,
                                         *args,
                                         **kwargs)
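

# Usage sketch for `create_model` (the model name and image path are
# illustrative):
#
#     model = create_model('PP-LCNet_x1_0')  # weights resolved via
#                                            # `official_models` when no
#                                            # model_dir is passed
#     result = model.predict({'input_path': 'demo.jpg'})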