base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from abc import ABC, abstractmethod
  16. from copy import deepcopy
  17. from types import GeneratorType
  18. from ...utils.flags import INFER_BENCHMARK
  19. from ...utils import logging
  20. from ..utils.benchmark import Timer
  21. class BaseComponent(ABC):
  22. YIELD_BATCH = True
  23. KEEP_INPUT = True
  24. ENABLE_BATCH = False
  25. INPUT_KEYS = None
  26. OUTPUT_KEYS = None
  27. def __init__(self):
  28. self.inputs = self.DEAULT_INPUTS if hasattr(self, "DEAULT_INPUTS") else {}
  29. self.outputs = self.DEAULT_OUTPUTS if hasattr(self, "DEAULT_OUTPUTS") else {}
  30. if INFER_BENCHMARK:
  31. self.timer = Timer()
  32. self.apply = self.timer.watch_func(self.apply)
  33. def __call__(self, input_list):
  34. # use list type for batched data
  35. if not isinstance(input_list, list):
  36. input_list = [input_list]
  37. output_list = []
  38. for args, input_ in self._check_input(input_list):
  39. output = self.apply(**args)
  40. if not output:
  41. if self.YIELD_BATCH:
  42. yield input_list
  43. else:
  44. for item in input_list:
  45. yield item
  46. # output may be a generator, when the apply() uses yield
  47. if (
  48. isinstance(output, GeneratorType)
  49. or output.__class__.__name__ == "generator"
  50. ):
  51. # if output is a generator, use for-in to get every one batch output data and yield one by one
  52. for each_output in output:
  53. reassemble_data = self._check_output(each_output, input_)
  54. if self.YIELD_BATCH:
  55. yield reassemble_data
  56. else:
  57. for item in reassemble_data:
  58. yield item
  59. # if output is not a generator, process all data of that and yield, so use output_list to collect all reassemble_data
  60. else:
  61. reassemble_data = self._check_output(output, input_)
  62. output_list.extend(reassemble_data)
  63. # avoid yielding output_list when the output is a generator
  64. if len(output_list) > 0:
  65. if self.YIELD_BATCH:
  66. yield output_list
  67. else:
  68. for item in output_list:
  69. yield item
  70. def _check_input(self, input_list):
  71. # check if the value of input data meets the requirements of apply(),
  72. # and reassemble the parameters of apply() from input_list
  73. def _check_type(input_):
  74. if not isinstance(input_, dict):
  75. if len(self.inputs) == 1:
  76. key = list(self.inputs.keys())[0]
  77. input_ = {key: input_}
  78. else:
  79. raise Exception(
  80. f"The input must be a dict or a list of dict, unless the input of the component only requires one argument, but the component({self.__class__.__name__}) requires {list(self.inputs.keys())}!"
  81. )
  82. return input_
  83. def _check_args_key(args):
  84. sig = inspect.signature(self.apply)
  85. for param in sig.parameters.values():
  86. if param.kind == inspect.Parameter.VAR_KEYWORD:
  87. logging.debug(
  88. f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
  89. )
  90. continue
  91. if param.default == inspect.Parameter.empty and param.name not in args:
  92. raise Exception(
  93. f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
  94. )
  95. if self.inputs is None:
  96. return [({}, None)]
  97. if self.need_batch_input:
  98. args = {}
  99. for input_ in input_list:
  100. input_ = _check_type(input_)
  101. for k, v in self.inputs.items():
  102. if v not in input_:
  103. raise Exception(
  104. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  105. )
  106. if k not in args:
  107. args[k] = []
  108. args[k].append(input_.get(v))
  109. _check_args_key(args)
  110. reassemble_input = [(args, input_list)]
  111. else:
  112. reassemble_input = []
  113. for input_ in input_list:
  114. input_ = _check_type(input_)
  115. args = {}
  116. for k, v in self.inputs.items():
  117. if v not in input_:
  118. raise Exception(
  119. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  120. )
  121. args[k] = input_.get(v)
  122. _check_args_key(args)
  123. reassemble_input.append((args, input_))
  124. return reassemble_input
  125. def _check_output(self, output, ori_data):
  126. # check if the value of apply() output data meets the requirements of setting
  127. # when the output data is list type, reassemble each of that
  128. if isinstance(output, list):
  129. if self.need_batch_input:
  130. assert isinstance(ori_data, list) and len(ori_data) == len(output)
  131. output_list = []
  132. for ori_item, output_item in zip(ori_data, output):
  133. data = ori_item.copy() if self.keep_input else {}
  134. if isinstance(self.outputs, type(None)):
  135. logging.debug(
  136. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  137. )
  138. data.update(output_item)
  139. else:
  140. for k, v in self.outputs.items():
  141. if k not in output_item:
  142. raise Exception(
  143. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  144. )
  145. data.update({v: output_item[k]})
  146. output_list.append(data)
  147. return output_list
  148. else:
  149. assert isinstance(ori_data, dict)
  150. output_list = []
  151. for output_item in output:
  152. data = ori_data.copy() if self.keep_input else {}
  153. if isinstance(self.outputs, type(None)):
  154. logging.debug(
  155. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  156. )
  157. data.update(output_item)
  158. else:
  159. for k, v in self.outputs.items():
  160. if k not in output_item:
  161. raise Exception(
  162. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  163. )
  164. data.update({v: output_item[k]})
  165. output_list.append(data)
  166. return output_list
  167. else:
  168. assert isinstance(ori_data, dict) and isinstance(output, dict)
  169. data = ori_data.copy() if self.keep_input else {}
  170. if isinstance(self.outputs, type(None)):
  171. logging.debug(
  172. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  173. )
  174. data.update(output)
  175. else:
  176. for k, v in self.outputs.items():
  177. if k not in output:
  178. raise Exception(
  179. f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
  180. )
  181. data.update({v: output[k]})
  182. return [data]
  183. def set_inputs(self, inputs):
  184. assert isinstance(inputs, dict)
  185. input_keys = deepcopy(self.INPUT_KEYS)
  186. # e.g, input_keys is None or []
  187. if input_keys is None or (
  188. isinstance(input_keys, list) and len(input_keys) == 0
  189. ):
  190. self.inputs = {}
  191. if inputs:
  192. raise Exception
  193. return
  194. # e.g, input_keys is 'img'
  195. if not isinstance(input_keys, list):
  196. input_keys = [[input_keys]]
  197. # e.g, input_keys is ['img'] or [['img']]
  198. elif len(input_keys) > 0:
  199. # e.g, input_keys is ['img']
  200. if not isinstance(input_keys[0], list):
  201. input_keys = [input_keys]
  202. ck_pass = False
  203. for key_group in input_keys:
  204. for key in key_group:
  205. if key not in inputs:
  206. break
  207. # check pass
  208. else:
  209. ck_pass = True
  210. if ck_pass == True:
  211. break
  212. else:
  213. raise Exception(
  214. f"The input {input_keys} are needed by {self.__class__.__name__}. But only get: {list(inputs.keys())}"
  215. )
  216. self.inputs = inputs
  217. def set_outputs(self, outputs):
  218. assert isinstance(outputs, dict)
  219. output_keys = deepcopy(self.OUTPUT_KEYS)
  220. if not isinstance(output_keys, list):
  221. output_keys = [output_keys]
  222. for k in output_keys:
  223. if k not in outputs:
  224. logging.debug(
  225. f"The output ({k}) of {self.__class__.__name__} would be abandon!"
  226. )
  227. self.outputs = outputs
  228. @classmethod
  229. def get_input_keys(cls) -> list:
  230. return cls.input_keys
  231. @classmethod
  232. def get_output_keys(cls) -> list:
  233. return cls.output_keys
  234. @property
  235. def need_batch_input(self):
  236. return getattr(self, "ENABLE_BATCH", False)
  237. @property
  238. def keep_input(self):
  239. return getattr(self, "KEEP_INPUT", True)
  240. @property
  241. def name(self):
  242. return getattr(self, "NAME", self.__class__.__name__)
  243. @property
  244. def sub_cmps(self):
  245. return None
  246. @abstractmethod
  247. def apply(self, input):
  248. raise NotImplementedError
  249. class ComponentsEngine(object):
  250. def __init__(self, ops):
  251. self.ops = ops
  252. self.keys = list(ops.keys())
  253. def __call__(self, data, i=0):
  254. data_gen = self.ops[self.keys[i]](data)
  255. if i + 1 < len(self.ops):
  256. for data in data_gen:
  257. yield from self.__call__(data, i + 1)
  258. else:
  259. yield from data_gen