base.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from copy import deepcopy
  16. from abc import ABC
  17. from types import GeneratorType
  18. from ...utils import logging
  19. class BaseComponent(ABC):
  20. INPUT_KEYS = None
  21. OUTPUT_KEYS = None
  22. def __init__(self):
  23. self.inputs = self.DEAULT_INPUTS if hasattr(self, "DEAULT_INPUTS") else {}
  24. self.outputs = self.DEAULT_OUTPUTS if hasattr(self, "DEAULT_OUTPUTS") else {}
  25. def __call__(self, input_list):
  26. # use list type for batched data
  27. if not isinstance(input_list, list):
  28. input_list = [input_list]
  29. output_list = []
  30. for args, input_ in self._check_input(input_list):
  31. output = self.apply(**args)
  32. if not output:
  33. yield input_list
  34. # output may be a generator, when the apply() uses yield
  35. if isinstance(output, GeneratorType):
  36. # if output is a generator, use for-in to get every one batch output data and yield one by one
  37. for each_output in output:
  38. reassemble_data = self._check_output(each_output, input_)
  39. yield reassemble_data
  40. # if output is not a generator, process all data of that and yield, so use output_list to collect all reassemble_data
  41. else:
  42. reassemble_data = self._check_output(output, input_)
  43. output_list.extend(reassemble_data)
  44. # avoid yielding output_list when the output is a generator
  45. if len(output_list) > 0:
  46. yield output_list
  47. def _check_input(self, input_list):
  48. # check if the value of input data meets the requirements of apply(),
  49. # and reassemble the parameters of apply() from input_list
  50. def _check_type(input_):
  51. if not isinstance(input_, dict):
  52. if len(self.inputs) == 1:
  53. key = list(self.inputs.keys())[0]
  54. input_ = {key: input_}
  55. else:
  56. raise Exception(
  57. f"The input must be a dict or a list of dict, unless the input of the component only requires one argument, but the component({self.__class__.__name__}) requires {list(self.inputs.keys())}!"
  58. )
  59. return input_
  60. def _check_args_key(args):
  61. sig = inspect.signature(self.apply)
  62. for param in sig.parameters.values():
  63. if param.kind == inspect.Parameter.VAR_KEYWORD:
  64. logging.debug(
  65. f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
  66. )
  67. continue
  68. if param.default == inspect.Parameter.empty and param.name not in args:
  69. raise Exception(
  70. f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
  71. )
  72. if self.need_batch_input:
  73. args = {}
  74. for input_ in input_list:
  75. input_ = _check_type(input_)
  76. for k, v in self.inputs.items():
  77. if v not in input_:
  78. raise Exception(
  79. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  80. )
  81. if k not in args:
  82. args[k] = []
  83. args[k].append(input_.get(v))
  84. _check_args_key(args)
  85. reassemble_input = [(args, input_list)]
  86. else:
  87. reassemble_input = []
  88. for input_ in input_list:
  89. input_ = _check_type(input_)
  90. args = {}
  91. for k, v in self.inputs.items():
  92. if v not in input_:
  93. raise Exception(
  94. f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
  95. )
  96. args[k] = input_.get(v)
  97. _check_args_key(args)
  98. reassemble_input.append((args, input_))
  99. return reassemble_input
  100. def _check_output(self, output, ori_data):
  101. # check if the value of apply() output data meets the requirements of setting
  102. # when the output data is list type, reassemble each of that
  103. if isinstance(output, list):
  104. if self.need_batch_input:
  105. assert isinstance(ori_data, list) and len(ori_data) == len(output)
  106. output_list = []
  107. for ori_item, output_item in zip(ori_data, output):
  108. data = ori_item.copy() if self.keep_ori else {}
  109. for k, v in self.outputs.items():
  110. if k not in output_item:
  111. raise Exception(
  112. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  113. )
  114. data.update({v: output_item[k]})
  115. output_list.append(data)
  116. return output_list
  117. else:
  118. assert isinstance(ori_data, dict)
  119. output_list = []
  120. for output_item in output:
  121. data = ori_data.copy() if self.keep_ori else {}
  122. for k, v in self.outputs.items():
  123. if k not in output_item:
  124. raise Exception(
  125. f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
  126. )
  127. data.update({v: output_item[k]})
  128. output_list.append(data)
  129. return output_list
  130. else:
  131. assert isinstance(ori_data, dict) and isinstance(output, dict)
  132. data = ori_data.copy() if self.keep_ori else {}
  133. for k, v in self.outputs.items():
  134. if k not in output:
  135. raise Exception(
  136. f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
  137. )
  138. data.update({v: output[k]})
  139. return [data]
  140. def set_inputs(self, inputs):
  141. assert isinstance(inputs, dict)
  142. input_keys = deepcopy(self.INPUT_KEYS)
  143. # e.g, input_keys is None or []
  144. if input_keys is None or (
  145. isinstance(input_keys, list) and len(input_keys) == 0
  146. ):
  147. self.inputs = {}
  148. if inputs:
  149. raise Exception
  150. return
  151. # e.g, input_keys is 'img'
  152. if not isinstance(input_keys, list):
  153. input_keys = [[input_keys]]
  154. # e.g, input_keys is ['img'] or [['img']]
  155. elif len(input_keys) > 0:
  156. # e.g, input_keys is ['img']
  157. if not isinstance(input_keys[0], list):
  158. input_keys = [input_keys]
  159. ck_pass = False
  160. for key_group in input_keys:
  161. for key in key_group:
  162. if key not in inputs:
  163. break
  164. # check pass
  165. else:
  166. ck_pass = True
  167. if ck_pass == True:
  168. break
  169. else:
  170. raise Exception(
  171. f"The input {input_keys} are needed by {self.__class__.__name__}. But only get: {list(inputs.keys())}"
  172. )
  173. self.inputs = inputs
  174. def set_outputs(self, outputs):
  175. assert isinstance(outputs, dict)
  176. output_keys = deepcopy(self.OUTPUT_KEYS)
  177. if not isinstance(output_keys, list):
  178. output_keys = [output_keys]
  179. for k in output_keys:
  180. if k not in outputs:
  181. logging.debug(
  182. f"The output ({k}) of {self.__class__.__name__} would be abandon!"
  183. )
  184. self.outputs = outputs
  185. @classmethod
  186. def get_input_keys(cls) -> list:
  187. return cls.input_keys
  188. @classmethod
  189. def get_output_keys(cls) -> list:
  190. return cls.output_keys
  191. @property
  192. def need_batch_input(self):
  193. return getattr(self, "ENABLE_BATCH", False)
  194. @property
  195. def keep_ori(self):
  196. return getattr(self, "KEEP_INPUT", True)
  197. class ComponentsEngine(object):
  198. def __init__(self, ops):
  199. self.ops = ops
  200. self.keys = list(ops.keys())
  201. def __call__(self, data, i=0):
  202. data_gen = self.ops[self.keys[i]](data)
  203. if i + 1 < len(self.ops):
  204. for data in data_gen:
  205. yield from self.__call__(data, i + 1)
  206. else:
  207. yield from data_gen