predictor.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from copy import deepcopy
  16. from abc import ABC, abstractmethod
  17. from .utils.paddle_inference_predictor import _PaddleInferencePredictor, PaddleInferenceOption
  18. from .utils.mixin import FromDictMixin
  19. from .utils.batch import batchable_method, Batcher
  20. from .utils.node import Node
  21. from .utils.official_models import official_models
  22. from ....utils.device import get_device
  23. from ....utils import logging
  24. from ....utils.config import AttrDict
  25. class BasePredictor(ABC, FromDictMixin, Node):
  26. """ Base Predictor """
  27. __is_base = True
  28. MODEL_FILE_TAG = 'inference'
  29. def __init__(self,
  30. model_dir,
  31. kernel_option,
  32. pre_transforms=None,
  33. post_transforms=None):
  34. super().__init__()
  35. self.model_dir = model_dir
  36. self.pre_transforms = pre_transforms
  37. self.post_transforms = post_transforms
  38. self.kernel_option = kernel_option
  39. param_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdiparams")
  40. model_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdmodel")
  41. self._predictor = _PaddleInferencePredictor(
  42. param_path=param_path, model_path=model_path, option=kernel_option)
  43. self.other_src = self.load_other_src()
  44. def predict(self, input, batch_size=1):
  45. """ predict """
  46. if not isinstance(input, dict) and not (isinstance(input, list) and all(
  47. isinstance(ele, dict) for ele in input)):
  48. raise TypeError(f"`input` should be a dict or a list of dicts.")
  49. orig_input = input
  50. if isinstance(input, dict):
  51. input = [input]
  52. logging.info(
  53. f"Running {self.__class__.__name__}\nModel: {self.model_dir}\nEnv: {self.kernel_option}\n"
  54. )
  55. data = input[0]
  56. if self.pre_transforms is not None:
  57. pre_tfs = self.pre_transforms
  58. else:
  59. pre_tfs = self._get_pre_transforms_for_data(data)
  60. logging.info(
  61. f"The following transformation operators will be used for data preprocessing:\n\
  62. {self._format_transforms(pre_tfs)}\n")
  63. if self.post_transforms is not None:
  64. post_tfs = self.post_transforms
  65. else:
  66. post_tfs = self._get_post_transforms_for_data(data)
  67. logging.info(
  68. f"The following transformation operators will be used for postprocessing:\n\
  69. {self._format_transforms(post_tfs)}\n")
  70. output = []
  71. for mini_batch in Batcher(input, batch_size=batch_size):
  72. mini_batch = self._preprocess(mini_batch, pre_transforms=pre_tfs)
  73. for data in mini_batch:
  74. self.check_input_keys(data)
  75. mini_batch = self._run(batch_input=mini_batch)
  76. for data in mini_batch:
  77. self.check_output_keys(data)
  78. mini_batch = self._postprocess(mini_batch, post_transforms=post_tfs)
  79. output.extend(mini_batch)
  80. if isinstance(orig_input, dict):
  81. return output[0]
  82. else:
  83. return output
  84. @abstractmethod
  85. def _run(self, batch_input):
  86. raise NotImplementedError
  87. @abstractmethod
  88. def _get_pre_transforms_for_data(self, data):
  89. """ get preprocess transforms """
  90. raise NotImplementedError
  91. @abstractmethod
  92. def _get_post_transforms_for_data(self, data):
  93. """ get postprocess transforms """
  94. raise NotImplementedError
  95. @batchable_method
  96. def _preprocess(self, data, pre_transforms):
  97. """ preprocess """
  98. for tf in pre_transforms:
  99. data = tf(data)
  100. return data
  101. @batchable_method
  102. def _postprocess(self, data, post_transforms):
  103. """ postprocess """
  104. for tf in post_transforms:
  105. data = tf(data)
  106. return data
  107. def _format_transforms(self, transforms):
  108. """ format transforms """
  109. lines = ['[']
  110. for tf in transforms:
  111. s = '\t'
  112. s += str(tf)
  113. lines.append(s)
  114. lines.append(']')
  115. return '\n'.join(lines)
  116. def load_other_src(self):
  117. """ load other source
  118. """
  119. return None
  120. class PredictorBuilderByConfig(object):
  121. """build model predictor
  122. """
  123. def __init__(self, config):
  124. """
  125. Args:
  126. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  127. """
  128. model_name = config.Global.model
  129. device = config.Global.device.split(':')[0]
  130. predict_config = deepcopy(config.Predict)
  131. model_dir = predict_config.pop('model_dir')
  132. kernel_setting = predict_config.pop('kernel_option', {})
  133. kernel_setting.setdefault('device', device)
  134. kernel_option = PaddleInferenceOption(**kernel_setting)
  135. self.input_path = predict_config.pop('input_path')
  136. self.predictor = BasePredictor.get(model_name)(model_dir, kernel_option,
  137. **predict_config)
  138. self.output = config.Global.output
  139. def __call__(self):
  140. data = {
  141. "input_path": self.input_path,
  142. "cli_flag": True,
  143. "output_dir": self.output
  144. }
  145. self.predictor.predict(data)
  146. def build_predictor(*args, **kwargs):
  147. """build predictor by config for dev
  148. """
  149. return PredictorBuilderByConfig(*args, **kwargs)
  150. def create_model(model_name,
  151. model_dir=None,
  152. kernel_option=None,
  153. pre_transforms=None,
  154. post_transforms=None,
  155. *args,
  156. **kwargs):
  157. """create model for predicting using inference model
  158. """
  159. kernel_option = PaddleInferenceOption(
  160. ) if kernel_option is None else kernel_option
  161. model_dir = official_models[model_name] if model_dir is None else model_dir
  162. return BasePredictor.get(model_name)(model_dir, kernel_option,
  163. pre_transforms, post_transforms, *args,
  164. **kwargs)