# predictor.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from copy import deepcopy
  16. from abc import ABC, abstractmethod
  17. from .kernel_option import PaddleInferenceOption
  18. from .utils.paddle_inference_predictor import _PaddleInferencePredictor
  19. from .utils.mixin import FromDictMixin
  20. from .utils.batch import batchable_method, Batcher
  21. from .utils.node import Node
  22. from .utils.official_models import official_models
  23. from ....utils.device import get_device
  24. from ....utils import logging
  25. from ....utils.config import AttrDict
  26. class BasePredictor(ABC, FromDictMixin, Node):
  27. """Base Predictor"""
  28. __is_base = True
  29. MODEL_FILE_TAG = "inference"
  30. def __init__(
  31. self,
  32. model_name,
  33. model_dir,
  34. kernel_option,
  35. output,
  36. pre_transforms=None,
  37. post_transforms=None,
  38. disable_print=False,
  39. disable_save=False,
  40. ):
  41. super().__init__()
  42. self.model_name = model_name
  43. self.model_dir = model_dir
  44. self.kernel_option = kernel_option
  45. self.output = output
  46. self.disable_print = disable_print
  47. self.disable_save = disable_save
  48. self.other_src = self.load_other_src()
  49. logging.debug(
  50. f"-------------------- {self.__class__.__name__} --------------------\n\
  51. Model: {self.model_dir}\n\
  52. Env: {self.kernel_option}"
  53. )
  54. self.pre_tfs, self.post_tfs = self.build_transforms(
  55. pre_transforms, post_transforms
  56. )
  57. self._predictor = _PaddleInferencePredictor(
  58. model_dir=model_dir, model_prefix=self.MODEL_FILE_TAG, option=kernel_option
  59. )
  60. def build_transforms(self, pre_transforms, post_transforms):
  61. """build pre-transforms and post-transforms"""
  62. pre_tfs = (
  63. pre_transforms
  64. if pre_transforms is not None
  65. else self._get_pre_transforms_from_config()
  66. )
  67. logging.debug(f"Preprocess Ops: {self._format_transforms(pre_tfs)}")
  68. post_tfs = (
  69. post_transforms
  70. if post_transforms is not None
  71. else self._get_post_transforms_from_config()
  72. )
  73. logging.debug(f"Postprocessing: {self._format_transforms(post_tfs)}")
  74. return pre_tfs, post_tfs
  75. def predict(self, input, batch_size=1):
  76. """predict"""
  77. if not isinstance(input, dict) and not (
  78. isinstance(input, list) and all(isinstance(ele, dict) for ele in input)
  79. ):
  80. raise TypeError(f"`input` should be a dict or a list of dicts.")
  81. orig_input = input
  82. if isinstance(input, dict):
  83. input = [input]
  84. output = []
  85. for mini_batch in Batcher(input, batch_size=batch_size):
  86. mini_batch = self._preprocess(mini_batch, pre_transforms=self.pre_tfs)
  87. for data in mini_batch:
  88. self.check_input_keys(data)
  89. mini_batch = self._run(batch_input=mini_batch)
  90. for data in mini_batch:
  91. self.check_output_keys(data)
  92. mini_batch = self._postprocess(mini_batch, post_transforms=self.post_tfs)
  93. output.extend(mini_batch)
  94. if isinstance(orig_input, dict):
  95. return output[0]
  96. else:
  97. return output
  98. @abstractmethod
  99. def _run(self, batch_input):
  100. raise NotImplementedError
  101. @abstractmethod
  102. def _get_pre_transforms_from_config(self):
  103. """get preprocess transforms"""
  104. raise NotImplementedError
  105. @abstractmethod
  106. def _get_post_transforms_from_config(self):
  107. """get postprocess transforms"""
  108. raise NotImplementedError
  109. @batchable_method
  110. def _preprocess(self, data, pre_transforms):
  111. """preprocess"""
  112. for tf in pre_transforms:
  113. data = tf(data)
  114. return data
  115. @batchable_method
  116. def _postprocess(self, data, post_transforms):
  117. """postprocess"""
  118. for tf in post_transforms:
  119. data = tf(data)
  120. return data
  121. def _format_transforms(self, transforms):
  122. """format transforms"""
  123. ops_str = ", ".join([str(tf) for tf in transforms])
  124. return f"[{ops_str}]"
  125. def load_other_src(self):
  126. """load other source"""
  127. return None
  128. def get_input_keys(self):
  129. """get keys of input dict"""
  130. return self.pre_tfs[0].get_input_keys()
  131. class PredictorBuilderByConfig(object):
  132. """build model predictor"""
  133. def __init__(self, config):
  134. """
  135. Args:
  136. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  137. """
  138. model_name = config.Global.model
  139. device = config.Global.device
  140. predict_config = deepcopy(config.Predict)
  141. model_dir = predict_config.pop("model_dir")
  142. if model_dir is None:
  143. if model_name in official_models:
  144. model_dir = official_models[model_name]
  145. kernel_setting = predict_config.pop("kernel_option", {})
  146. kernel_setting.setdefault("device", device)
  147. kernel_option = PaddleInferenceOption(**kernel_setting)
  148. self.input_path = predict_config.pop("input_path")
  149. self.predictor = BasePredictor.get(model_name)(
  150. model_name=model_name,
  151. model_dir=model_dir,
  152. kernel_option=kernel_option,
  153. output=config.Global.output,
  154. **predict_config,
  155. )
  156. def predict(self):
  157. """predict"""
  158. self.predictor.predict({"input_path": self.input_path})
  159. def build_predictor(*args, **kwargs):
  160. """build predictor by config for dev"""
  161. return PredictorBuilderByConfig(*args, **kwargs)
  162. def create_model(
  163. model_name,
  164. model_dir=None,
  165. kernel_option=None,
  166. output="./",
  167. pre_transforms=None,
  168. post_transforms=None,
  169. *args,
  170. **kwargs,
  171. ):
  172. """create model for predicting using inference model"""
  173. kernel_option = PaddleInferenceOption() if kernel_option is None else kernel_option
  174. if model_dir is None:
  175. if model_name in official_models:
  176. model_dir = official_models[model_name]
  177. return BasePredictor.get(model_name)(
  178. model_name=model_name,
  179. model_dir=model_dir,
  180. kernel_option=kernel_option,
  181. output=output,
  182. pre_transforms=pre_transforms,
  183. post_transforms=post_transforms,
  184. *args,
  185. **kwargs,
  186. )