|
|
@@ -22,6 +22,10 @@ from ...utils import logging
|
|
|
|
|
|
class BaseComponent(ABC):
|
|
|
|
|
|
+ YIELD_BATCH = True
|
|
|
+ KEEP_INPUT = True
|
|
|
+ ENABLE_BATCH = False
|
|
|
+
|
|
|
INPUT_KEYS = None
|
|
|
OUTPUT_KEYS = None
|
|
|
|
|
|
@@ -38,14 +42,22 @@ class BaseComponent(ABC):
|
|
|
for args, input_ in self._check_input(input_list):
|
|
|
output = self.apply(**args)
|
|
|
if not output:
|
|
|
- yield input_list
|
|
|
+ if self.YIELD_BATCH:
|
|
|
+ yield input_list
|
|
|
+ else:
|
|
|
+ for item in input_list:
|
|
|
+ yield item
|
|
|
|
|
|
# output may be a generator, when the apply() uses yield
|
|
|
if isinstance(output, GeneratorType):
|
|
|
# if output is a generator, use for-in to get every one batch output data and yield one by one
|
|
|
for each_output in output:
|
|
|
reassemble_data = self._check_output(each_output, input_)
|
|
|
- yield reassemble_data
|
|
|
+ if self.YIELD_BATCH:
|
|
|
+ yield reassemble_data
|
|
|
+ else:
|
|
|
+ for item in reassemble_data:
|
|
|
+ yield item
|
|
|
# if output is not a generator, process all data of that and yield, so use output_list to collect all reassemble_data
|
|
|
else:
|
|
|
reassemble_data = self._check_output(output, input_)
|
|
|
@@ -53,7 +65,11 @@ class BaseComponent(ABC):
|
|
|
|
|
|
# avoid yielding output_list when the output is a generator
|
|
|
if len(output_list) > 0:
|
|
|
- yield output_list
|
|
|
+ if self.YIELD_BATCH:
|
|
|
+ yield output_list
|
|
|
+ else:
|
|
|
+ for item in output_list:
|
|
|
+ yield item
|
|
|
|
|
|
def _check_input(self, input_list):
|
|
|
# check if the value of input data meets the requirements of apply(),
|
|
|
@@ -119,7 +135,7 @@ class BaseComponent(ABC):
|
|
|
assert isinstance(ori_data, list) and len(ori_data) == len(output)
|
|
|
output_list = []
|
|
|
for ori_item, output_item in zip(ori_data, output):
|
|
|
- data = ori_item.copy() if self.keep_ori else {}
|
|
|
+ data = ori_item.copy() if self.keep_input else {}
|
|
|
for k, v in self.outputs.items():
|
|
|
if k not in output_item:
|
|
|
raise Exception(
|
|
|
@@ -132,7 +148,7 @@ class BaseComponent(ABC):
|
|
|
assert isinstance(ori_data, dict)
|
|
|
output_list = []
|
|
|
for output_item in output:
|
|
|
- data = ori_data.copy() if self.keep_ori else {}
|
|
|
+ data = ori_data.copy() if self.keep_input else {}
|
|
|
for k, v in self.outputs.items():
|
|
|
if k not in output_item:
|
|
|
raise Exception(
|
|
|
@@ -143,14 +159,14 @@ class BaseComponent(ABC):
|
|
|
return output_list
|
|
|
else:
|
|
|
assert isinstance(ori_data, dict) and isinstance(output, dict)
|
|
|
- data = ori_data.copy() if self.keep_ori else {}
|
|
|
+ data = ori_data.copy() if self.keep_input else {}
|
|
|
for k, v in self.outputs.items():
|
|
|
if k not in output:
|
|
|
raise Exception(
|
|
|
f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
|
|
|
)
|
|
|
data.update({v: output[k]})
|
|
|
- return [data]
|
|
|
+ return [data]
|
|
|
|
|
|
def set_inputs(self, inputs):
|
|
|
assert isinstance(inputs, dict)
|
|
|
@@ -216,7 +232,7 @@ class BaseComponent(ABC):
|
|
|
return getattr(self, "ENABLE_BATCH", False)
|
|
|
|
|
|
@property
|
|
|
- def keep_ori(self):
|
|
|
+ def keep_input(self):
|
|
|
return getattr(self, "KEEP_INPUT", True)
|
|
|
|
|
|
|