
upgrade

1. unify pipeline API
2. support setting the batch size for TSRead
gaotingquan, 1 year ago
parent
commit abee33d8a8
28 changed files with 363 additions and 240 deletions
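
For orientation, a minimal usage sketch of the API after this change, assuming the usual PaddleX entry points; the model names, directories, device strings and batch sizes below are illustrative assumptions, not taken from the diff.

    # Minimal sketch of the unified API (names, paths and values are assumptions).
    from paddlex.inference.pipelines import create_pipeline                 # per this commit
    from paddlex.inference.models.image_classification import ClasPredictor

    # Pipelines: create_pipeline() now forwards device / pp_option to every
    # predictor it builds; extra kwargs such as batch_size go to the pipeline.
    pipeline = create_pipeline("image_classification", device="gpu:0", batch_size=2)

    # Single-model predictors: per-call settings are routed through set_predictor(),
    # which now accepts batch_size for every BasicPredictor subclass.
    model = ClasPredictor(model_dir="PP-LCNet_x1_0_infer")                  # hypothetical dir
    for res in model.predict("./imgs", batch_size=4, device="gpu:0"):
        print(res)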
  1. paddlex/inference/components/base.py (+4 -0)
  2. paddlex/inference/components/paddle_predictor/predictor.py (+5 -9)
  3. paddlex/inference/components/transforms/image/common.py (+7 -36)
  4. paddlex/inference/components/transforms/read_data.py (+57 -0)
  5. paddlex/inference/components/transforms/ts/common.py (+18 -14)
  6. paddlex/inference/components/utils/mixin.py (+47 -0)
  7. paddlex/inference/models/base/__init__.py (+2 -3)
  8. paddlex/inference/models/base/base_predictor.py (+4 -72)
  9. paddlex/inference/models/base/basic_predictor.py (+93 -0)
  10. paddlex/inference/models/base/cv_predictor.py (+0 -22)
  11. paddlex/inference/models/base/ts_predictor.py (+0 -20)
  12. paddlex/inference/models/general_recognition.py (+3 -3)
  13. paddlex/inference/models/image_classification.py (+3 -3)
  14. paddlex/inference/models/image_unwarping.py (+3 -3)
  15. paddlex/inference/models/instance_segmentation.py (+1 -1)
  16. paddlex/inference/models/object_detection.py (+3 -3)
  17. paddlex/inference/models/semantic_segmentation.py (+3 -3)
  18. paddlex/inference/models/table_recognition.py (+3 -3)
  19. paddlex/inference/models/text_detection.py (+3 -3)
  20. paddlex/inference/models/text_recognition.py (+3 -3)
  21. paddlex/inference/models/ts_ad.py (+3 -3)
  22. paddlex/inference/models/ts_cls.py (+3 -3)
  23. paddlex/inference/models/ts_fc.py (+3 -3)
  24. paddlex/inference/models/utils/predict_set.py (+3 -3)
  25. paddlex/inference/pipelines/__init__.py (+19 -4)
  26. paddlex/inference/pipelines/ocr.py (+8 -0)
  27. paddlex/inference/pipelines/single_model_pipeline.py (+50 -15)
  28. paddlex/inference/pipelines/table_recognition/table_recognition.py (+12 -8)

+ 4 - 0
paddlex/inference/components/base.py

@@ -253,6 +253,10 @@ class BaseComponent(ABC):
     def keep_input(self):
         return getattr(self, "KEEP_INPUT", True)
 
+    @property
+    def name(self):
+        return getattr(self, "NAME", self.__class__.__name__)
+
 
 class ComponentsEngine(object):
     def __init__(self, ops):

+ 5 - 9
paddlex/inference/components/paddle_predictor/predictor.py

@@ -19,10 +19,11 @@ import numpy as np
 
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
+from ..utils.mixin import PPEngineMixin
 from ..base import BaseComponent
 
 
-class BasePaddlePredictor(BaseComponent):
+class BasePaddlePredictor(BaseComponent, PPEngineMixin):
     """Predictor based on Paddle Inference"""
 
     OUTPUT_KEYS = "pred"
@@ -31,12 +32,12 @@ class BasePaddlePredictor(BaseComponent):
 
     def __init__(self, model_dir, model_prefix, option: PaddlePredictorOption = None):
         super().__init__()
+        PPEngineMixin.__init__(self, option)
         self.model_dir = model_dir
         self.model_prefix = model_prefix
-        self.option = option
         self._is_initialized = False
 
-    def _build(self):
+    def _reset(self):
         if not self.option:
             self.option = PaddlePredictorOption()
         (
@@ -163,7 +164,7 @@ No need to generate again."
 
     def apply(self, **kwargs):
         if not self._is_initialized:
-            self._build()
+            self._reset()
 
         x = self.to_batch(**kwargs)
         for idx in range(len(x)):
@@ -180,11 +181,6 @@ No need to generate again."
     def format_output(self, pred):
         return [{"pred": res} for res in zip(*pred)]
 
-    def set_option(self, option):
-        if option != self.option:
-            self.option = option
-            self._build()
-
     @abstractmethod
     def to_batch(self):
         raise NotImplementedError

+ 7 - 36
paddlex/inference/components/transforms/image/common.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import math
 
 from pathlib import Path
@@ -21,10 +20,11 @@ from copy import deepcopy
 import numpy as np
 import cv2
 
-from .....utils.download import download
 from .....utils.cache import CACHE_DIR
 from ....utils.io import ImageReader, ImageWriter
+from ...utils.mixin import BatchSizeMixin
 from ...base import BaseComponent
+from ..read_data import _BaseRead
 from . import funcs as F
 
 __all__ = [
@@ -52,7 +52,7 @@ def _check_image_size(input_):
         raise TypeError(f"{input_} cannot represent a valid image size.")
 
 
-class ReadImage(BaseComponent):
+class ReadImage(_BaseRead):
     """Load image from the file."""
 
     INPUT_KEYS = ["img"]
@@ -71,6 +71,7 @@ class ReadImage(BaseComponent):
         "RGB": cv2.IMREAD_COLOR,
         "GRAY": cv2.IMREAD_GRAYSCALE,
     }
+
     SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
 
     def __init__(self, batch_size=1, format="BGR"):
@@ -81,8 +82,7 @@ class ReadImage(BaseComponent):
             format (str, optional): Target color format to convert the image to.
                 Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
         """
-        super().__init__()
-        self.batch_size = batch_size
+        super().__init__(batch_size)
         self.format = format
         flags = self._FLAGS_DICT[self.format]
         self._reader = ImageReader(backend="opencv", flags=flags)
@@ -104,11 +104,10 @@ class ReadImage(BaseComponent):
             ]
         else:
             img_path = img
-            # XXX: auto download for url
             img_path = self._download_from_url(img_path)
-            image_list = self._get_image_list(img_path)
+            file_list = self._get_files_list(img_path)
             batch = []
-            for img_path in image_list:
+            for img_path in file_list:
                 img = self._read_img(img_path)
                 batch.append(img)
                 if len(batch) >= self.batch_size:
@@ -135,34 +134,6 @@ class ReadImage(BaseComponent):
             "ori_img_size": deepcopy([blob.shape[1], blob.shape[0]]),
         }
 
-    def _download_from_url(self, in_path):
-        if in_path.startswith("http"):
-            file_name = Path(in_path).name
-            save_path = Path(CACHE_DIR) / "predict_input" / file_name
-            download(in_path, save_path, overwrite=True)
-            return save_path.as_posix()
-        return in_path
-
-    def _get_image_list(self, img_file):
-        imgs_lists = []
-        if img_file is None or not os.path.exists(img_file):
-            raise Exception(f"Not found any img file in path: {img_file}")
-
-        if os.path.isfile(img_file) and img_file.split(".")[-1] in self.SUFFIX:
-            imgs_lists.append(img_file)
-        elif os.path.isdir(img_file):
-            for root, dirs, files in os.walk(img_file):
-                for single_file in files:
-                    if single_file.split(".")[-1] in self.SUFFIX:
-                        imgs_lists.append(os.path.join(root, single_file))
-        if len(imgs_lists) == 0:
-            raise Exception("not found any img file in {}".format(img_file))
-        imgs_lists = sorted(imgs_lists)
-        return imgs_lists
-
-    def set_batch_size(self, batch_size):
-        self.batch_size = batch_size
-
 
 class GetImageInfo(BaseComponent):
     """Get Image Info"""

+ 57 - 0
paddlex/inference/components/transforms/read_data.py

@@ -0,0 +1,57 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+
+from ....utils.download import download
+from ....utils.cache import CACHE_DIR
+from ..utils.mixin import BatchSizeMixin
+from ..base import BaseComponent
+
+
+class _BaseRead(BaseComponent, BatchSizeMixin):
+    """Load image from the file."""
+
+    SUFFIX = []
+
+    def __init__(self, batch_size=1):
+        super().__init__()
+        BatchSizeMixin.__init__(self, batch_size)
+
+    # XXX: auto download for url
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
+    def _get_files_list(self, fp):
+        file_list = []
+        if fp is None or not os.path.exists(fp):
+            raise Exception(f"Not found any img file in path: {fp}")
+
+        if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
+            file_list.append(fp)
+        elif os.path.isdir(fp):
+            for root, dirs, files in os.walk(fp):
+                for single_file in files:
+                    if single_file.split(".")[-1] in self.SUFFIX:
+                        file_list.append(os.path.join(root, single_file))
+        if len(file_list) == 0:
+            raise Exception("Not found any file in {}".format(fp))
+        file_list = sorted(file_list)
+        return file_list

+ 18 - 14
paddlex/inference/components/transforms/ts/common.py

@@ -23,6 +23,7 @@ from .....utils.cache import CACHE_DIR
 from ....utils.io.readers import TSReader
 from ....utils.io.writers import TSWriter
 from ...base import BaseComponent
+from ..read_data import _BaseRead
 from .funcs import load_from_dataframe, time_feature
 
 
@@ -41,15 +42,17 @@ __all__ = [
 ]
 
 
-class ReadTS(BaseComponent):
+class ReadTS(_BaseRead):
 
     INPUT_KEYS = ["ts"]
     OUTPUT_KEYS = ["ts_path", "ts", "ori_ts"]
     DEAULT_INPUTS = {"ts": "ts"}
     DEAULT_OUTPUTS = {"ts_path": "ts_path", "ts": "ts", "ori_ts": "ori_ts"}
 
-    def __init__(self):
-        super().__init__()
+    SUFFIX = ["csv"]
+
+    def __init__(self, batch_size=1):
+        super().__init__(batch_size)
         self._reader = TSReader(backend="pandas")
         self._writer = TSWriter(backend="pandas")
 
@@ -60,18 +63,19 @@ class ReadTS(BaseComponent):
             return {"ts_path": ts_path, "ts": ts, "ori_ts": deepcopy(ts)}
 
         ts_path = ts
-        # XXX: auto download for url
         ts_path = self._download_from_url(ts_path)
-        ts = self._reader.read(ts_path)
-        return {"ts_path": ts_path, "ts": ts, "ori_ts": deepcopy(ts)}
-
-    def _download_from_url(self, in_path):
-        if in_path.startswith("http"):
-            file_name = Path(in_path).name
-            save_path = Path(CACHE_DIR) / "predict_input" / file_name
-            download(in_path, save_path, overwrite=True)
-            return save_path.as_posix()
-        return in_path
+        file_list = self._get_files_list(ts_path)
+        batch = []
+        for ts_path in file_list:
+            ts_data = self._reader.read(ts_path)
+            batch.append(
+                {"ts_path": ts_path, "ts": ts_data, "ori_ts": deepcopy(ts_data)}
+            )
+            if len(batch) >= self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0:
+            yield batch
 
 
 class TSCutOff(BaseComponent):
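
With ReadTS rebuilt on _BaseRead, a TS predictor can now consume a whole directory of CSV files in batches. A minimal sketch, assuming the inference config is read from the model directory as in the existing predictors; the model directory and input folder are made up.

    # Sketch: ReadTS walks a directory of *.csv files and yields batches of
    # {"ts_path", "ts", "ori_ts"} dicts, so batch_size now applies to TS models too.
    from paddlex.inference.models.ts_cls import TSClsPredictor              # per this commit

    model = TSClsPredictor(model_dir="TimesNet_cls_infer")                  # hypothetical dir
    for res in model.predict("./ts_inputs", batch_size=4):                  # folder of CSVs
        print(res)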

+ 47 - 0
paddlex/inference/components/utils/mixin.py

@@ -0,0 +1,47 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class BatchSizeMixin:
+    NAME = "ReadCmp"
+
+    def __init__(self, batch_size=1):
+        self._batch_size = batch_size
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value):
+        if value <= 0:
+            raise ValueError("Batch size must be positive.")
+        self._batch_size = value
+
+
+class PPEngineMixin:
+    NAME = "PPEngineCmp"
+
+    def __init__(self, option=None):
+        self._option = option
+
+    @property
+    def option(self):
+        return self._option
+
+    @option.setter
+    def option(self, value):
+        if value != self.option:
+            self._option = value
+            self._reset()
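
A small sketch of how these mixins cooperate with BaseComponent.name and the key-based lookup in predict_set.py; DummyReader is a made-up class, and it assumes apply() is the only abstract hook a component must implement.

    # Sketch: a component mixing in BatchSizeMixin is keyed as "ReadCmp" (its NAME),
    # which is exactly the key BatchSizeSetMixin looks up; likewise PPEngineMixin's
    # option setter calls _reset() when a different option object is assigned.
    from paddlex.inference.components.base import BaseComponent
    from paddlex.inference.components.utils.mixin import BatchSizeMixin

    class DummyReader(BaseComponent, BatchSizeMixin):        # made-up example component
        def __init__(self, batch_size=1):
            super().__init__()
            BatchSizeMixin.__init__(self, batch_size)

        def apply(self, **kwargs):                           # assumed abstract hook
            yield []

    reader = DummyReader()
    assert reader.name == "ReadCmp"   # BaseComponent.name resolves NAME from the mixin
    reader.batch_size = 8             # validated by the BatchSizeMixin setter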

+ 2 - 3
paddlex/inference/models/base/__init__.py

@@ -12,6 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_predictor import BasePredictor, BasicPredictor
-from .cv_predictor import CVPredictor
-from .ts_predictor import TSPredictor
+from .base_predictor import BasePredictor
+from .basic_predictor import BasicPredictor

+ 4 - 72
paddlex/inference/models/base/base_predictor.py

@@ -17,14 +17,10 @@ import codecs
 from pathlib import Path
 from abc import abstractmethod
 
-from ....utils.subclass_register import AutoRegisterABCMetaClass
-from ....utils.func_register import FuncRegister
-from ....utils import logging
-from ...utils.device import constr_device
-from ...components.base import BaseComponent, ComponentsEngine
+from ...components.base import BaseComponent
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.process_hook import generatorable_method
-from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin
+from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
 
 
 class BasePredictor(BaseComponent):
@@ -48,7 +44,7 @@ class BasePredictor(BaseComponent):
         self.predict = self.__call__
 
     def __call__(self, input, **kwargs):
-        self.set_predict(**kwargs)
+        self.set_predictor(**kwargs)
         for res in super().__call__(input):
             yield res["result"]
 
@@ -65,7 +61,7 @@ class BasePredictor(BaseComponent):
         raise NotImplementedError
 
     @abstractmethod
-    def set_predict(self):
+    def set_predictor(self):
         raise NotImplementedError
 
     @classmethod
@@ -78,67 +74,3 @@ class BasePredictor(BaseComponent):
         with codecs.open(config_path, "r", "utf-8") as file:
             dic = yaml.load(file, Loader=yaml.FullLoader)
         return dic
-
-
-class BasicPredictor(
-    BasePredictor, DeviceSetMixin, PPOptionSetMixin, metaclass=AutoRegisterABCMetaClass
-):
-
-    __is_base = True
-
-    def __init__(self, model_dir, config=None, device=None, pp_option=None):
-        super().__init__(model_dir=model_dir, config=config)
-        self._pred_set_func_map = {}
-        self._pred_set_register = FuncRegister(self._pred_set_func_map)
-        self._pred_set_register("device")(self.set_device)
-        self._pred_set_register("pp_option")(self.set_pp_option)
-
-        self.pp_option = pp_option if pp_option else PaddlePredictorOption()
-        self.pp_option.set_device(device)
-        self.components = {}
-        self._build_components()
-        self.engine = ComponentsEngine(self.components)
-        logging.debug(
-            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
-        )
-
-    def apply(self, x):
-        """predict"""
-        yield from self._generate_res(self.engine(x))
-
-    @generatorable_method
-    def _generate_res(self, batch_data):
-        return [{"result": self._pack_res(data)} for data in batch_data]
-
-    def _add_component(self, cmps):
-        if not isinstance(cmps, list):
-            cmps = [cmps]
-
-        for cmp in cmps:
-            if not isinstance(cmp, (list, tuple)):
-                key = cmp.__class__.__name__
-            else:
-                assert len(cmp) == 2
-                key = cmp[0]
-                cmp = cmp[1]
-            assert isinstance(key, str)
-            assert isinstance(cmp, BaseComponent)
-            assert (
-                key not in self.components
-            ), f"The key ({key}) has been used: {self.components}!"
-            self.components[key] = cmp
-
-    def set_predict(self, **kwargs):
-        for k in kwargs:
-            assert (
-                k in self._pred_set_func_map
-            ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
-            self._pred_set_func_map[k](kwargs[k])
-
-    @abstractmethod
-    def _build_components(self):
-        raise NotImplementedError
-
-    @abstractmethod
-    def _pack_res(self, data):
-        raise NotImplementedError

+ 93 - 0
paddlex/inference/models/base/basic_predictor.py

@@ -0,0 +1,93 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+
+from ....utils.subclass_register import AutoRegisterABCMetaClass
+from ....utils.func_register import FuncRegister
+from ....utils import logging
+from ...components.base import BaseComponent, ComponentsEngine
+from ...utils.pp_option import PaddlePredictorOption
+from ...utils.process_hook import generatorable_method
+from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin, BatchSizeSetMixin
+from .base_predictor import BasePredictor
+
+
+class BasicPredictor(
+    BasePredictor,
+    DeviceSetMixin,
+    PPOptionSetMixin,
+    BatchSizeSetMixin,
+    metaclass=AutoRegisterABCMetaClass,
+):
+
+    __is_base = True
+
+    def __init__(self, model_dir, config=None, device=None, pp_option=None):
+        super().__init__(model_dir=model_dir, config=config)
+        self._pred_set_func_map = {}
+        self._pred_set_register = FuncRegister(self._pred_set_func_map)
+        self._pred_set_register("device")(self.set_device)
+        self._pred_set_register("pp_option")(self.set_pp_option)
+        self._pred_set_register("batch_size")(self.set_batch_size)
+
+        self.pp_option = pp_option if pp_option else PaddlePredictorOption()
+        self.pp_option.set_device(device)
+        self.components = {}
+        self._build_components()
+        self.engine = ComponentsEngine(self.components)
+        logging.debug(
+            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
+        )
+
+    def apply(self, x):
+        """predict"""
+        yield from self._generate_res(self.engine(x))
+
+    @generatorable_method
+    def _generate_res(self, batch_data):
+        return [{"result": self._pack_res(data)} for data in batch_data]
+
+    def _add_component(self, cmps):
+        if not isinstance(cmps, list):
+            cmps = [cmps]
+
+        for cmp in cmps:
+            if not isinstance(cmp, (list, tuple)):
+                key = cmp.name
+            else:
+                assert len(cmp) == 2
+                key = cmp[0]
+                cmp = cmp[1]
+            assert isinstance(key, str)
+            assert isinstance(cmp, BaseComponent)
+            assert (
+                key not in self.components
+            ), f"The key ({key}) has been used: {self.components}!"
+            self.components[key] = cmp
+
+    def set_predictor(self, **kwargs):
+        for k in kwargs:
+            assert (
+                k in self._pred_set_func_map
+            ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
+            self._pred_set_func_map[k](kwargs[k])
+
+    @abstractmethod
+    def _build_components(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def _pack_res(self, data):
+        raise NotImplementedError

+ 0 - 22
paddlex/inference/models/base/cv_predictor.py

@@ -1,22 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ..utils.predict_set import BatchSizeSetMixin
-from .base_predictor import BasicPredictor
-
-
-class CVPredictor(BasicPredictor, BatchSizeSetMixin):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._pred_set_register("batch_size")(self.set_batch_size)

+ 0 - 20
paddlex/inference/models/base/ts_predictor.py

@@ -1,20 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base_predictor import BasicPredictor
-
-
-class TSPredictor(BasicPredictor):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)

+ 3 - 3
paddlex/inference/models/general_recognition.py

@@ -18,10 +18,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.general_recognition.model_list import MODELS
 from ..components import *
 from ..results import BaseResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class ShiTuRecPredictor(CVPredictor):
+class ShiTuRecPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -42,7 +42,7 @@ class ShiTuRecPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         post_processes = self.config["PostProcess"]
         for key in post_processes:

+ 3 - 3
paddlex/inference/models/image_classification.py

@@ -19,10 +19,10 @@ from ...modules.image_classification.model_list import MODELS
 from ...modules.multilabel_classification.model_list import MODELS as ML_MODELS
 from ..components import *
 from ..results import TopkResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class ClasPredictor(CVPredictor):
+class ClasPredictor(BasicPredictor):
 
     entities = [*MODELS, *ML_MODELS]
 
@@ -43,7 +43,7 @@ class ClasPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         post_processes = self.config["PostProcess"]
         for key in post_processes:

+ 3 - 3
paddlex/inference/models/image_unwarping.py

@@ -15,10 +15,10 @@
 from ...modules.image_unwarping.model_list import MODELS
 from ..components import *
 from ..results import DocTrResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class WarpPredictor(CVPredictor):
+class WarpPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -36,7 +36,7 @@ class WarpPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component([("Predictor", predictor), DocTrPostProcess()])
+        self._add_component([predictor, DocTrPostProcess()])
 
     def _pack_res(self, single):
         keys = ["img_path", "doctr_img"]

+ 1 - 1
paddlex/inference/models/instance_segmentation.py

@@ -46,7 +46,7 @@ class InstanceSegPredictor(DetPredictor):
             )
         self._add_component(
             [
-                ("Predictor", predictor),
+                predictor,
                 InstanceSegPostProcess(
                     threshold=self.config["draw_threshold"],
                     labels=self.config["label_list"],

+ 3 - 3
paddlex/inference/models/object_detection.py

@@ -18,10 +18,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.object_detection.model_list import MODELS
 from ..components import *
 from ..results import DetResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class DetPredictor(CVPredictor):
+class DetPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -54,7 +54,7 @@ class DetPredictor(CVPredictor):
 
         self._add_component(
             [
-                ("Predictor", predictor),
+                predictor,
                 DetPostProcess(
                     threshold=self.config["draw_threshold"],
                     labels=self.config["label_list"],

+ 3 - 3
paddlex/inference/models/semantic_segmentation.py

@@ -18,10 +18,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.semantic_segmentation.model_list import MODELS
 from ..components import *
 from ..results import SegResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class SegPredictor(CVPredictor):
+class SegPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -44,7 +44,7 @@ class SegPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
     @register("Resize")
     def build_resize(

+ 3 - 3
paddlex/inference/models/table_recognition.py

@@ -19,10 +19,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.table_recognition.model_list import MODELS
 from ..components import *
 from ..results import TableRecResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class TablePredictor(CVPredictor):
+class TablePredictor(BasicPredictor):
     """table recognition predictor"""
 
     entities = MODELS
@@ -44,7 +44,7 @@ class TablePredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         op = self.build_postprocess(**self.config["PostProcess"])
         self._add_component(op)

+ 3 - 3
paddlex/inference/models/text_detection.py

@@ -18,10 +18,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.text_detection.model_list import MODELS
 from ..components import *
 from ..results import TextDetResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class TextDetPredictor(CVPredictor):
+class TextDetPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -42,7 +42,7 @@ class TextDetPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         op = self.build_postprocess(**self.config["PostProcess"])
         self._add_component(op)

+ 3 - 3
paddlex/inference/models/text_recognition.py

@@ -18,10 +18,10 @@ from ...utils.func_register import FuncRegister
 from ...modules.text_recognition.model_list import MODELS
 from ..components import *
 from ..results import TextRecResult
-from .base import CVPredictor
+from .base import BasicPredictor
 
 
-class TextRecPredictor(CVPredictor):
+class TextRecPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -43,7 +43,7 @@ class TextRecPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         op = self.build_postprocess(**self.config["PostProcess"])
         self._add_component(op)

+ 3 - 3
paddlex/inference/models/ts_ad.py

@@ -17,10 +17,10 @@ import os
 from ...modules.ts_anomaly_detection.model_list import MODELS
 from ..components import *
 from ..results import TSAdResult
-from .base import TSPredictor
+from .base import BasicPredictor
 
 
-class TSAdPredictor(TSPredictor):
+class TSAdPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -56,7 +56,7 @@ class TSAdPredictor(TSPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         self._add_component(
             GetAnomaly(self.config["model_threshold"], self.config["info_params"])

+ 3 - 3
paddlex/inference/models/ts_cls.py

@@ -16,10 +16,10 @@ import os
 from ...modules.ts_classification.model_list import MODELS
 from ..components import *
 from ..results import TSClsResult
-from .base import TSPredictor
+from .base import BasicPredictor
 
 
-class TSClsPredictor(TSPredictor):
+class TSClsPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -45,7 +45,7 @@ class TSClsPredictor(TSPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
         self._add_component(GetCls())
 
     def _pack_res(self, single):

+ 3 - 3
paddlex/inference/models/ts_fc.py

@@ -17,10 +17,10 @@ import os
 from ...modules.ts_forecast.model_list import MODELS
 from ..components import *
 from ..results import TSFcResult
-from .base import TSPredictor
+from .base import BasicPredictor
 
 
-class TSFcPredictor(TSPredictor):
+class TSFcPredictor(BasicPredictor):
 
     entities = MODELS
 
@@ -56,7 +56,7 @@ class TSFcPredictor(TSPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        self._add_component(("Predictor", predictor))
+        self._add_component(predictor)
 
         self._add_component(ArraytoTS(self.config["info_params"]))
         if self.config.get("scale", None):

+ 3 - 3
paddlex/inference/models/utils/predict_set.py

@@ -15,16 +15,16 @@
 
 class BatchSizeSetMixin:
     def set_batch_size(self, batch_size):
-        self.components["ReadImage"].set_batch_size(batch_size)
+        self.components["ReadCmp"].batch_size = batch_size
 
 
 class DeviceSetMixin:
     def set_device(self, device):
         self.pp_option.set_device(device)
-        self.components["Predictor"].set_option(self.pp_option)
+        self.components["PPEngineCmp"].option = self.pp_option
 
 
 class PPOptionSetMixin:
     def set_pp_option(self, pp_option):
         self.pp_option = pp_option
-        self.components["Predictor"].set_option(self.pp_option)
+        self.components["PPEngineCmp"].option = self.pp_option
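
The resulting dispatch chain, sketched with an illustrative predictor; it assumes the predictor's component graph contains a Read* component (keyed "ReadCmp") and a Paddle engine component (keyed "PPEngineCmp"), as in the changes above, and the model directory is hypothetical.

    # Sketch of the new setter chain (predictor and directory are illustrative).
    from paddlex.inference.models.text_recognition import TextRecPredictor

    predictor = TextRecPredictor(model_dir="PP-OCRv4_mobile_rec_infer")
    predictor.set_predictor(batch_size=8)    # set_batch_size -> components["ReadCmp"].batch_size = 8
    predictor.set_predictor(device="gpu:0")  # set_device -> pp_option.set_device(...) then
                                             #               components["PPEngineCmp"].option = pp_option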

+ 19 - 4
paddlex/inference/pipelines/__init__.py

@@ -17,13 +17,27 @@ from typing import Any, Dict, Optional
 
 from ...utils.config import parse_config
 from .base import BasePipeline
-from .single_model_pipeline import SingleModelPipeline
+from .single_model_pipeline import (
+    _SingleModelPipeline,
+    ImageClassification,
+    ObjectDetection,
+    InstanceSegmentation,
+    SemanticSegmentation,
+    TSFc,
+    TSAd,
+    TSCls,
+    MultiLableImageClas,
+    SmallObjDet,
+    AnomolyDetection,
+)
 from .ocr import OCRPipeline
 from .table_recognition import TableRecPipeline
 
 
 def create_pipeline(
     pipeline: str,
+    device=None,
+    pp_option=None,
     use_hpip: bool = False,
     hpi_params: Optional[Dict[str, Any]] = None,
     *args,
@@ -39,17 +53,18 @@ def create_pipeline(
     """
     if not Path(pipeline).exists():
         # XXX: using dict class to handle all pipeline configs
-        pipeline = (
+        build_in_pipeline = (
             Path(__file__).parent.parent.parent / "pipelines" / f"{pipeline}.yaml"
         )
-        if not Path(pipeline).exists():
+        if not Path(build_in_pipeline).exists():
             raise Exception(f"The pipeline don't exist! ({pipeline})")
+        pipeline = build_in_pipeline
     config = parse_config(pipeline)
     pipeline_name = config["Global"]["pipeline_name"]
     pipeline_setting = config["Pipeline"]
     pipeline_setting.update(kwargs)
 
-    predictor_kwargs = {"use_hpip": use_hpip}
+    predictor_kwargs = {"device": device, "pp_option": pp_option, "use_hpip": use_hpip}
     if hpi_params is not None:
         predictor_kwargs["hpi_params"] = hpi_params
 

+ 8 - 0
paddlex/inference/pipelines/ocr.py

@@ -26,9 +26,14 @@ class OCRPipeline(BasePipeline):
         self,
         det_model,
         rec_model,
+        batch_size=1,
         predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
+        self._build_predictor(det_model, rec_model)
+        self.set_predictor(batch_size)
+
+    def _build_predictor(self, det_model, rec_model):
         self.det_model = self._create_model(det_model)
         self.rec_model = self._create_model(rec_model)
         self.is_curve = self.det_model.model_name in [
@@ -40,6 +45,9 @@ class OCRPipeline(BasePipeline):
             det_box_type="poly" if self.is_curve else "quad"
         )
 
+    def set_predictor(self, batch_size):
+        self.rec_model.set_predictor(batch_size=batch_size)
+
     def predict(self, input, **kwargs):
         device = kwargs.get("device", "gpu")
         for det_res in self.det_model(

+ 50 - 15
paddlex/inference/pipelines/single_model_pipeline.py

@@ -15,23 +15,58 @@
 from .base import BasePipeline
 
 
-class SingleModelPipeline(BasePipeline):
-
-    entities = [
-        "image_classification",
-        "object_detection",
-        "instance_segmentation",
-        "semantic_segmentation",
-        "ts_fc",
-        "ts_ad",
-        "ts_cls",
-        "multi_label_image_classification",
-        "small_object_detection" "anomaly_detection",
-    ]
-
-    def __init__(self, model, predictor_kwargs=None):
+class _SingleModelPipeline(BasePipeline):
+
+    def __init__(self, model, batch_size=1, predictor_kwargs=None):
         super().__init__(predictor_kwargs)
+        self._build_predictor(model)
+        self.set_predictor(batch_size)
+
+    def _build_predictor(self, model):
         self.model = self._create_model(model)
 
+    def set_predictor(self, batch_size):
+        self.model.set_predictor(batch_size=batch_size)
+
     def predict(self, input, **kwargs):
         yield from self.model(input, **kwargs)
+
+
+class ImageClassification(_SingleModelPipeline):
+    entities = "image_classification"
+
+
+class ObjectDetection(_SingleModelPipeline):
+    entities = "object_detection"
+
+
+class InstanceSegmentation(_SingleModelPipeline):
+    entities = "instance_segmentation"
+
+
+class SemanticSegmentation(_SingleModelPipeline):
+    entities = "semantic_segmentation"
+
+
+class TSFc(_SingleModelPipeline):
+    entities = "ts_fc"
+
+
+class TSAd(_SingleModelPipeline):
+    entities = "ts_ad"
+
+
+class TSCls(_SingleModelPipeline):
+    entities = "ts_cls"
+
+
+class MultiLableImageClas(_SingleModelPipeline):
+    entities = "multi_label_image_classification"
+
+
+class SmallObjDet(_SingleModelPipeline):
+    entities = "small_object_detection"
+
+
+class AnomolyDetection(_SingleModelPipeline):
+    entities = "anomaly_detection"
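
Each built-in single-model pipeline is now its own class identified by an entities string; a new one would follow the same pattern (the class and pipeline name below are made up).

    # Made-up example of the per-entity pipeline pattern introduced here:
    class DocOrientation(_SingleModelPipeline):
        entities = "doc_orientation"    # hypothetical pipeline name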

+ 12 - 8
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -31,12 +31,18 @@ class TableRecPipeline(BasePipeline):
         text_det_model,
         text_rec_model,
         table_model,
-        batch_size=1,
-        device="gpu",
+        layout_batch_size=1,
+        text_rec_batch_size=1,
+        table_batch_size=1,
         predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
+        self._build_predictor(layout_model, text_det_model, text_rec_model, table_model)
+        self.set_predictor(layout_batch_size, text_rec_batch_size, table_batch_size)
 
+    def _build_predictor(
+        self, layout_model, text_det_model, text_rec_model, table_model
+    ):
         self.layout_predictor = self._create_model(model=layout_model)
         self.ocr_pipeline = OCRPipeline(
             text_det_model,
@@ -46,13 +52,11 @@ class TableRecPipeline(BasePipeline):
         self.table_predictor = self._create_model(model=table_model)
         self._crop_by_boxes = CropByBoxes()
         self._match = TableMatch(filter_ocr_result=False)
-        self.set_predictor(batch_size=batch_size, device=device)
 
-    def set_predictor(self, batch_size, device):
-        self.layout_predictor.set_predict(device=device, batch_size=batch_size)
-        self.ocr_pipeline.det_model.set_predict(device=device)
-        self.ocr_pipeline.rec_model.set_predict(device=device, batch_size=batch_size)
-        self.table_predictor.set_predict(device=device, batch_size=batch_size)
+    def set_predictor(self, layout_batch_size, text_rec_batch_size, table_batch_size):
+        self.layout_predictor.set_predictor(batch_size=layout_batch_size)
+        self.ocr_pipeline.rec_model.set_predictor(batch_size=text_rec_batch_size)
+        self.table_predictor.set_predictor(batch_size=table_batch_size)
 
     def predict(self, x):
         for layout_pred, ocr_pred in zip(