
upgrade pipeline and predictor API

gaotingquan 1 year ago
Parent
Current commit
a79cad8557

+ 24 - 8
paddlex/inference/components/base.py

@@ -22,6 +22,10 @@ from ...utils import logging
 
 class BaseComponent(ABC):
 
+    YIELD_BATCH = True
+    KEEP_INPUT = True
+    ENABLE_BATCH = False
+
     INPUT_KEYS = None
     OUTPUT_KEYS = None
 
@@ -38,14 +42,22 @@ class BaseComponent(ABC):
         for args, input_ in self._check_input(input_list):
             output = self.apply(**args)
             if not output:
-                yield input_list
+                if self.YIELD_BATCH:
+                    yield input_list
+                else:
+                    for item in input_list:
+                        yield item
 
             # output may be a generator, when the apply() uses yield
             if isinstance(output, GeneratorType):
                 # if output is a generator, use for-in to get every one batch output data and yield one by one
                 for each_output in output:
                     reassemble_data = self._check_output(each_output, input_)
-                    yield reassemble_data
+                    if self.YIELD_BATCH:
+                        yield reassemble_data
+                    else:
+                        for item in reassemble_data:
+                            yield item
             # if output is not a generator, process all data of that and yield, so use output_list to collect all reassemble_data
             else:
                 reassemble_data = self._check_output(output, input_)
@@ -53,7 +65,11 @@ class BaseComponent(ABC):
 
         # avoid yielding output_list when the output is a generator
         if len(output_list) > 0:
-            yield output_list
+            if self.YIELD_BATCH:
+                yield output_list
+            else:
+                for item in output_list:
+                    yield item
 
     def _check_input(self, input_list):
         # check if the value of input data meets the requirements of apply(),
@@ -119,7 +135,7 @@ class BaseComponent(ABC):
                 assert isinstance(ori_data, list) and len(ori_data) == len(output)
                 output_list = []
                 for ori_item, output_item in zip(ori_data, output):
-                    data = ori_item.copy() if self.keep_ori else {}
+                    data = ori_item.copy() if self.keep_input else {}
                     for k, v in self.outputs.items():
                         if k not in output_item:
                             raise Exception(
@@ -132,7 +148,7 @@ class BaseComponent(ABC):
                 assert isinstance(ori_data, dict)
                 output_list = []
                 for output_item in output:
-                    data = ori_data.copy() if self.keep_ori else {}
+                    data = ori_data.copy() if self.keep_input else {}
                     for k, v in self.outputs.items():
                         if k not in output_item:
                             raise Exception(
@@ -143,14 +159,14 @@ class BaseComponent(ABC):
                 return output_list
         else:
             assert isinstance(ori_data, dict) and isinstance(output, dict)
-            data = ori_data.copy() if self.keep_ori else {}
+            data = ori_data.copy() if self.keep_input else {}
             for k, v in self.outputs.items():
                 if k not in output:
                     raise Exception(
                         f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
                     )
                 data.update({v: output[k]})
-        return [data]
+            return [data]
 
     def set_inputs(self, inputs):
         assert isinstance(inputs, dict)
@@ -216,7 +232,7 @@ class BaseComponent(ABC):
         return getattr(self, "ENABLE_BATCH", False)
 
     @property
-    def keep_ori(self):
+    def keep_input(self):
         return getattr(self, "KEEP_INPUT", True)
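
The new class-level flags control how `BaseComponent.__call__` emits results: `YIELD_BATCH` decides whether a whole batch list or individual items are yielded, `KEEP_INPUT` whether each output dict starts from a copy of the input, and `ENABLE_BATCH` whether `apply()` receives a batch. A minimal sketch of a hypothetical component opting into per-item yielding (the inputs/outputs wiring is assumed to be configured elsewhere, e.g. via `set_inputs()`):

```python
# Minimal sketch (hypothetical component, not part of this commit) showing the
# effect of the new class-level flags on BaseComponent.__call__.
from paddlex.inference.components.base import BaseComponent


class ScaleValue(BaseComponent):
    YIELD_BATCH = False   # __call__ yields one output dict at a time
    KEEP_INPUT = False    # output dicts are not seeded with a copy of the input
    ENABLE_BATCH = False  # apply() is called per sample, not per batch

    INPUT_KEYS = "x"
    OUTPUT_KEYS = "y"

    def apply(self, x):
        # Called once per sample because ENABLE_BATCH is False.
        return {"y": x * 2}
```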
 
 

+ 1 - 5
paddlex/inference/pipelines/__init__.py

@@ -12,10 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .image_classification import ClasPipeline
+from .single_model_pipeline import SingleModelPipeline
 from .ocr import OCRPipeline
 from .table_recognition import TableRecPipeline
-from .object_detection import DetPipeline
-from .instance_segmentation import InstanceSegPipeline
-from .semantic_segmentation import SegPipeline
-from .general_recognition import ShiTuRecPipeline

+ 1 - 0
paddlex/inference/pipelines/base.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, Optional
 
 from ..predictors import create_predictor
 from ...utils.subclass_register import AutoRegisterABCMetaClass
+from ..predictors import create_predictor
 
 
 def create_pipeline(

+ 0 - 34
paddlex/inference/pipelines/image_classification.py

@@ -1,34 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base import BasePipeline
-
-
-class ClasPipeline(BasePipeline):
-    """Cls Pipeline"""
-
-    entities = "image_classification"
-
-    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
-        super().__init__(predictor_kwargs)
-        self._predict = self._create_predictor(
-            model, batch_size=batch_size, device=device
-        )
-
-    def predict(self, x):
-        self._check_input(x)
-        yield from self._predict(x)
-
-    def _check_input(self, x):
-        pass

+ 0 - 34
paddlex/inference/pipelines/instance_segmentation.py

@@ -1,34 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base import BasePipeline
-
-
-class InstanceSegPipeline(BasePipeline):
-    """InstanceSeg Pipeline"""
-
-    entities = "instance_segmentation"
-
-    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
-        super().__init__(predictor_kwargs)
-        self._predict = self._create_predictor(
-            model, batch_size=batch_size, device=device
-        )
-
-    def predict(self, x):
-        self._check_input(x)
-        yield from self._predict(x)
-
-    def _check_input(self, x):
-        pass

+ 0 - 34
paddlex/inference/pipelines/object_detection.py

@@ -1,34 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base import BasePipeline
-
-
-class DetPipeline(BasePipeline):
-    """Det Pipeline"""
-
-    entities = "object_detection"
-
-    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
-        super().__init__(predictor_kwargs)
-        self._predict = self._create_predictor(
-            model, batch_size=batch_size, device=device
-        )
-
-    def predict(self, x):
-        self._check_input(x)
-        yield from self._predict(x)
-
-    def _check_input(self, x):
-        pass

+ 12 - 26
paddlex/inference/pipelines/ocr.py

@@ -23,36 +23,22 @@ class OCRPipeline(BasePipeline):
     entities = "ocr"
 
     def __init__(
-        self,
-        det_model,
-        rec_model,
-        det_batch_size,
-        rec_batch_size,
-        predictor_kwargs=None,
-        **kwargs
+        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, **kwargs
     ):
         super().__init__(predictor_kwargs)
-        self._det_predict = self._create_predictor(det_model, batch_size=det_batch_size)
+        self._det_predict = self._create_predictor(det_model)
         self._rec_predict = self._create_predictor(rec_model, batch_size=rec_batch_size)
         # TODO: foo
         self._crop_by_polys = CropByPolys(det_box_type="foo")
 
     def predict(self, x):
-        batch_ocr_res = []
-        for batch_det_res in self._det_predict(x):
-            for det_res in batch_det_res:
-                single_img_res = det_res["result"]
-                single_img_res["rec_text"] = []
-                single_img_res["rec_score"] = []
-                if len(single_img_res["dt_polys"]) > 0:
-                    all_subs_of_img = list(self._crop_by_polys(single_img_res))
-                    for batch_rec_res in self._rec_predict(all_subs_of_img):
-                        for rec_res in batch_rec_res:
-                            single_img_res["rec_text"].append(
-                                rec_res["result"]["rec_text"]
-                            )
-                            single_img_res["rec_score"].append(
-                                rec_res["result"]["rec_score"]
-                            )
-                batch_ocr_res.append({"result": OCRResult(single_img_res)})
-        yield batch_ocr_res
+        for det_res in self._det_predict(x):
+            single_img_res = det_res
+            single_img_res["rec_text"] = []
+            single_img_res["rec_score"] = []
+            if len(single_img_res["dt_polys"]) > 0:
+                all_subs_of_img = list(self._crop_by_polys(single_img_res))
+                for rec_res in self._rec_predict(all_subs_of_img):
+                    single_img_res["rec_text"].append(rec_res["rec_text"])
+                    single_img_res["rec_score"].append(rec_res["rec_score"])
+            yield single_img_res

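With this change `OCRPipeline.predict()` no longer collects a batch-level list; it yields one per-image result dict at a time, with the recognized texts and scores appended to the detection result. A rough usage sketch, assuming the model names and input path below (placeholders, not taken from this commit):

```python
# Usage sketch only; model names and the input path are placeholders.
from paddlex.inference.pipelines.ocr import OCRPipeline

pipeline = OCRPipeline(
    det_model="PP-OCRv4_mobile_det",
    rec_model="PP-OCRv4_mobile_rec",
    rec_batch_size=1,
)
for res in pipeline.predict("sample.png"):
    # One dict per input image: detection polygons plus appended recognition results.
    print(res["dt_polys"], res["rec_text"], res["rec_score"])
```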
+ 0 - 34
paddlex/inference/pipelines/semantic_segmentation.py

@@ -1,34 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base import BasePipeline
-
-
-class SegPipeline(BasePipeline):
-    """Det Pipeline"""
-
-    entities = "semantic_segmentation"
-
-    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
-        super().__init__(predictor_kwargs)
-        self._predict = self._create_predictor(
-            model, batch_size=batch_size, device=device
-        )
-
-    def predict(self, x):
-        self._check_input(x)
-        yield from self._predict(x)
-
-    def _check_input(self, x):
-        pass

+ 7 - 7
paddlex/inference/pipelines/general_recognition.py → paddlex/inference/pipelines/single_model_pipeline.py

@@ -15,10 +15,14 @@
 from .base import BasePipeline
 
 
-class ShiTuRecPipeline(BasePipeline):
-    """ShiTu Rec Pipeline"""
+class SingleModelPipeline(BasePipeline):
 
-    entities = "general_recognition"
+    entities = [
+        "image_classification",
+        "object_detection",
+        "instance_segmentation",
+        "semantic_segmentation",
+    ]
 
     def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
         super().__init__(predictor_kwargs)
@@ -27,8 +31,4 @@ class ShiTuRecPipeline(BasePipeline):
         )
 
     def predict(self, x):
-        self._check_input(x)
         yield from self._predict(x)
-
-    def _check_input(self, x):
-        pass

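The four single-model pipelines are now served by one class: `entities` is a list, so `image_classification`, `object_detection`, `instance_segmentation`, and `semantic_segmentation` all resolve to `SingleModelPipeline`. A minimal sketch using the constructor signature shown above (the model name is a placeholder):

```python
# Minimal sketch; "PP-LCNet_x1_0" is a placeholder model name.
from paddlex.inference.pipelines import SingleModelPipeline

pipeline = SingleModelPipeline(model="PP-LCNet_x1_0", batch_size=1, device="gpu")
for res in pipeline.predict("sample.png"):
    print(res)
```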
+ 10 - 4
paddlex/inference/predictors/base.py

@@ -36,13 +36,14 @@ def _get_default_device():
 
 
 class BasePredictor(BaseComponent):
+    KEEP_INPUT = False
+    YIELD_BATCH = False
+
     INPUT_KEYS = "x"
     DEAULT_INPUTS = {"x": "x"}
     OUTPUT_KEYS = "result"
     DEAULT_OUTPUTS = {"result": "result"}
 
-    KEEP_INPUT = False
-
     MODEL_FILE_PREFIX = "inference"
 
     def __init__(self, model_dir, config=None, device=None, **kwargs):
@@ -54,6 +55,10 @@ class BasePredictor(BaseComponent):
         # alias predict() to the __call__()
         self.predict = self.__call__
 
+    def __call__(self, *args, **kwargs):
+        for res in super().__call__(*args, **kwargs):
+            yield res["result"]
+
     @property
     def config_path(self):
         return self.get_config_path(self.model_dir)
@@ -82,6 +87,7 @@ class BasePredictor(BaseComponent):
 
 
 class BasicPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
+
     __is_base = True
 
     def __init__(self, model_dir, config=None, device=None, pp_option=None, **kwargs):
@@ -99,8 +105,8 @@ class BasicPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
         yield from self._generate_res(self.engine(x))
 
     @generatorable_method
-    def _generate_res(self, data):
-        return self._pack_res(data)
+    def _generate_res(self, batch_data):
+        return [{"result": self._pack_res(data)} for data in batch_data]
 
     @abstractmethod
     def _build_components(self):

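`BasePredictor` now opts out of batch yielding (`YIELD_BATCH = False`), `_generate_res` wraps each packed sample in a `{"result": ...}` dict, and the new `__call__` override unwraps it again, so callers iterate over plain result objects one at a time. A self-contained sketch of that flow (simplified names, not the actual classes):

```python
# Simplified, self-contained sketch of the new result flow; not the real classes.
def _generate_res(batch_data):
    # Each sample is packed and wrapped in a {"result": ...} dict ...
    return [{"result": {"packed": d}} for d in batch_data]

def predict(batch_data):
    # ... and __call__ unwraps it, yielding one result per sample.
    for res in _generate_res(batch_data):
        yield res["result"]

print(list(predict([1, 2, 3])))  # [{'packed': 1}, {'packed': 2}, {'packed': 3}]
```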
+ 3 - 4
paddlex/inference/predictors/image_classification.py

@@ -103,9 +103,8 @@ class ClasPredictor(BasicPredictor):
     def build_threshoutput(self, threshold, label_list=None):
         return MultiLabelThreshOutput(threshold=float(threshold), class_ids=label_list)
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "class_ids", "scores"]
-        if "label_names" in data:
+        if "label_names" in single:
             keys.append("label_names")
-        return {"result": TopkResult({key: data[key] for key in keys})}
+        return TopkResult({key: single[key] for key in keys})

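`_pack_res` now takes a single sample dict and returns the result object directly; the batch loop and the `{"result": ...}` wrapping moved into `_generate_res`, so `@batchable_method` is no longer needed. The same pattern repeats in the predictor diffs below. A minimal sketch, with a plain dict standing in for the result class:

```python
# Hypothetical predictor; a plain dict stands in for the Result class.
class MyPredictor:
    def _pack_res(self, single):
        keys = ["img_path", "pred"]
        # Previously: @batchable_method returning {"result": Result(...)} per item.
        return {key: single[key] for key in keys}
```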
+ 2 - 3
paddlex/inference/predictors/instance_segmentation.py

@@ -53,7 +53,6 @@ class InstanceSegPredictor(DetPredictor):
 
         return ops
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "boxes", "masks", "labels"]
-        return {"result": InstanceSegResult({key: data[key] for key in keys})}
+        return InstanceSegResult({key: single[key] for key in keys})

+ 2 - 3
paddlex/inference/predictors/object_detection.py

@@ -93,7 +93,6 @@ class DetPredictor(BasicPredictor):
     def build_to_chw(self):
         return ToCHWImage()
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "boxes", "labels"]
-        return {"result": DetResult({key: data[key] for key in keys})}
+        return DetResult({key: single[key] for key in keys})

+ 2 - 3
paddlex/inference/predictors/semantic_segmentation.py

@@ -90,7 +90,6 @@ class SegPredictor(BasicPredictor):
     ):
         return Normalize(mean=mean, std=std)
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "pred"]
-        return {"result": SegResult({key: data[key] for key in keys})}
+        return SegResult({key: single[key] for key in keys})

+ 2 - 3
paddlex/inference/predictors/table_recognition.py

@@ -99,7 +99,6 @@ class TablePredictor(BasicPredictor):
     def foo(self, *args, **kwargs):
         return None
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "bbox", "structure"]
-        return {"result": TableRecResult({key: data[key] for key in keys})}
+        return TableRecResult({key: single[key] for key in keys})

+ 2 - 3
paddlex/inference/predictors/text_detection.py

@@ -99,7 +99,6 @@ class TextDetPredictor(BasicPredictor):
     def foo(self, *args, **kwargs):
         return None
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "dt_polys", "dt_scores"]
-        return {"result": TextDetResult({key: data[key] for key in keys})}
+        return TextDetResult({key: single[key] for key in keys})

+ 2 - 3
paddlex/inference/predictors/text_recognition.py

@@ -76,7 +76,6 @@ class TextRecPredictor(BasicPredictor):
     def foo(self, *args, **kwargs):
         return None
 
-    @batchable_method
-    def _pack_res(self, data):
+    def _pack_res(self, single):
         keys = ["img_path", "rec_text", "rec_score"]
-        return {"result": TextRecResult({key: data[key] for key in keys})}
+        return TextRecResult({key: single[key] for key in keys})

+ 2 - 7
paddlex/inference/predictors/ts_cls.py

@@ -56,10 +56,5 @@ class TSClsPredictor(BasicPredictor):
 
         return ops
 
-    @batchable_method
-    def _pack_res(self, data):
-        return {
-            "result": TSClsResult(
-                {"ts_path": data["ts_path"], "forecast": data["pred"]}
-            )
-        }
+    def _pack_res(self, single):
+        return TSClsResult({"ts_path": single["ts_path"], "forecast": single["pred"]})

+ 2 - 5
paddlex/inference/predictors/ts_fc.py

@@ -80,8 +80,5 @@ class TSFcPredictor(BasicPredictor):
             )
         return ops
 
-    @batchable_method
-    def _pack_res(self, data):
-        return {
-            "result": TSFcResult({"ts_path": data["ts_path"], "forecast": data["pred"]})
-        }
+    def _pack_res(self, single):
+        return TSFcResult({"ts_path": single["ts_path"], "forecast": single["pred"]})