base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from abc import ABC, abstractmethod
  16. from copy import deepcopy
  17. from types import GeneratorType
  18. from ...utils import logging
  19. class BaseComponent(ABC):
  20. YIELD_BATCH = True
  21. KEEP_INPUT = True
  22. ENABLE_BATCH = False
  23. INPUT_KEYS = None
  24. OUTPUT_KEYS = None
  25. def __init__(self):
  26. self.inputs = self.DEAULT_INPUTS if hasattr(self, "DEAULT_INPUTS") else {}
  27. self.outputs = self.DEAULT_OUTPUTS if hasattr(self, "DEAULT_OUTPUTS") else {}
  28. def __call__(self, input_list):
  29. # use list type for batched data
  30. if not isinstance(input_list, list):
  31. input_list = [input_list]
  32. output_list = []
  33. for args, input_ in self._check_input(input_list):
  34. output = self.apply(**args)
  35. if not output:
  36. if self.YIELD_BATCH:
  37. yield input_list
  38. else:
  39. for item in input_list:
  40. yield item
  41. # output may be a generator, when the apply() uses yield
  42. if (
  43. isinstance(output, GeneratorType)
  44. or output.__class__.__name__ == "generator"
  45. ):
  46. # if output is a generator, use for-in to get every one batch output data and yield one by one
  47. for each_output in output:
  48. reassemble_data = self._check_output(each_output, input_)
  49. if self.YIELD_BATCH:
  50. yield reassemble_data
  51. else:
  52. for item in reassemble_data:
  53. yield item
  54. # if output is not a generator, process all data of that and yield, so use output_list to collect all reassemble_data
  55. else:
  56. reassemble_data = self._check_output(output, input_)
  57. output_list.extend(reassemble_data)
  58. # avoid yielding output_list when the output is a generator
  59. if len(output_list) > 0:
  60. if self.YIELD_BATCH:
  61. yield output_list
  62. else:
  63. for item in output_list:
  64. yield item
  65. def _check_input(self, input_list):
  66. # check if the value of input data meets the requirements of apply(),
  67. # and reassemble the parameters of apply() from input_list
  68. def _check_type(input_):
  69. if not isinstance(input_, dict):
  70. if len(self.inputs) == 1:
  71. key = list(self.inputs.keys())[0]
  72. input_ = {key: input_}
  73. else:
  74. raise Exception(
  75. f"The input must be a dict or a list of dict, unless the input of the component only requires one argument, but the component({self.__class__.__name__}) requires {list(self.inputs.keys())}!"
  76. )
  77. return input_
  78. def _check_args_key(args):
  79. sig = inspect.signature(self.apply)
  80. for param in sig.parameters.values():
  81. if param.kind == inspect.Parameter.VAR_KEYWORD:
  82. logging.debug(
  83. f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
  84. )
  85. continue
  86. if param.default == inspect.Parameter.empty and param.name not in args:
  87. raise Exception(
  88. f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
  89. )
  90. if self.need_batch_input:
  91. args = {}
  92. for input_ in input_list:
  93. input_ = _check_type(input_)
  94. for k, v in self.inputs.items():
  95. if v not in input_:
  96. raise Exception(
  97. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  98. )
  99. if k not in args:
  100. args[k] = []
  101. args[k].append(input_.get(v))
  102. _check_args_key(args)
  103. reassemble_input = [(args, input_list)]
  104. else:
  105. reassemble_input = []
  106. for input_ in input_list:
  107. input_ = _check_type(input_)
  108. args = {}
  109. for k, v in self.inputs.items():
  110. if v not in input_:
  111. raise Exception(
  112. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  113. )
  114. args[k] = input_.get(v)
  115. _check_args_key(args)
  116. reassemble_input.append((args, input_))
  117. return reassemble_input
  118. def _check_output(self, output, ori_data):
  119. # check if the value of apply() output data meets the requirements of setting
  120. # when the output data is list type, reassemble each of that
  121. if isinstance(output, list):
  122. if self.need_batch_input:
  123. assert isinstance(ori_data, list) and len(ori_data) == len(output)
  124. output_list = []
  125. for ori_item, output_item in zip(ori_data, output):
  126. data = ori_item.copy() if self.keep_input else {}
  127. if isinstance(self.outputs, type(None)):
  128. logging.debug(
  129. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  130. )
  131. data.update(output_item)
  132. else:
  133. for k, v in self.outputs.items():
  134. if k not in output_item:
  135. raise Exception(
  136. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  137. )
  138. data.update({v: output_item[k]})
  139. output_list.append(data)
  140. return output_list
  141. else:
  142. assert isinstance(ori_data, dict)
  143. output_list = []
  144. for output_item in output:
  145. data = ori_data.copy() if self.keep_input else {}
  146. if isinstance(self.outputs, type(None)):
  147. logging.debug(
  148. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  149. )
  150. data.update(output_item)
  151. else:
  152. for k, v in self.outputs.items():
  153. if k not in output_item:
  154. raise Exception(
  155. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  156. )
  157. data.update({v: output_item[k]})
  158. output_list.append(data)
  159. return output_list
  160. else:
  161. assert isinstance(ori_data, dict) and isinstance(output, dict)
  162. data = ori_data.copy() if self.keep_input else {}
  163. if isinstance(self.outputs, type(None)):
  164. logging.debug(
  165. f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
  166. )
  167. data.update(output)
  168. else:
  169. for k, v in self.outputs.items():
  170. if k not in output:
  171. raise Exception(
  172. f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
  173. )
  174. data.update({v: output[k]})
  175. return [data]
  176. def set_inputs(self, inputs):
  177. assert isinstance(inputs, dict)
  178. input_keys = deepcopy(self.INPUT_KEYS)
  179. # e.g, input_keys is None or []
  180. if input_keys is None or (
  181. isinstance(input_keys, list) and len(input_keys) == 0
  182. ):
  183. self.inputs = {}
  184. if inputs:
  185. raise Exception
  186. return
  187. # e.g, input_keys is 'img'
  188. if not isinstance(input_keys, list):
  189. input_keys = [[input_keys]]
  190. # e.g, input_keys is ['img'] or [['img']]
  191. elif len(input_keys) > 0:
  192. # e.g, input_keys is ['img']
  193. if not isinstance(input_keys[0], list):
  194. input_keys = [input_keys]
  195. ck_pass = False
  196. for key_group in input_keys:
  197. for key in key_group:
  198. if key not in inputs:
  199. break
  200. # check pass
  201. else:
  202. ck_pass = True
  203. if ck_pass == True:
  204. break
  205. else:
  206. raise Exception(
  207. f"The input {input_keys} are needed by {self.__class__.__name__}. But only get: {list(inputs.keys())}"
  208. )
  209. self.inputs = inputs
  210. def set_outputs(self, outputs):
  211. assert isinstance(outputs, dict)
  212. output_keys = deepcopy(self.OUTPUT_KEYS)
  213. if not isinstance(output_keys, list):
  214. output_keys = [output_keys]
  215. for k in output_keys:
  216. if k not in outputs:
  217. logging.debug(
  218. f"The output ({k}) of {self.__class__.__name__} would be abandon!"
  219. )
  220. self.outputs = outputs
  221. @classmethod
  222. def get_input_keys(cls) -> list:
  223. return cls.input_keys
  224. @classmethod
  225. def get_output_keys(cls) -> list:
  226. return cls.output_keys
  227. @property
  228. def need_batch_input(self):
  229. return getattr(self, "ENABLE_BATCH", False)
  230. @property
  231. def keep_input(self):
  232. return getattr(self, "KEEP_INPUT", True)
  233. @property
  234. def name(self):
  235. return getattr(self, "NAME", self.__class__.__name__)
  236. @abstractmethod
  237. def apply(self, input):
  238. raise NotImplementedError
  239. class ComponentsEngine(object):
  240. def __init__(self, ops):
  241. self.ops = ops
  242. self.keys = list(ops.keys())
  243. def __call__(self, data, i=0):
  244. data_gen = self.ops[self.keys[i]](data)
  245. if i + 1 < len(self.ops):
  246. for data in data_gen:
  247. yield from self.__call__(data, i + 1)
  248. else:
  249. yield from data_gen