gaotingquan 1 سال پیش
والد
کامیت
dddf7b040f
30فایلهای تغییر یافته به همراه304 افزوده شده و 223 حذف شده
  1. 53 32
      paddlex/inference/components/paddle_predictor/predictor.py
  2. 3 0
      paddlex/inference/components/transforms/image/common.py
  3. 1 3
      paddlex/inference/models/__init__.py
  4. 16 0
      paddlex/inference/models/base/__init__.py
  5. 50 30
      paddlex/inference/models/base/base_predictor.py
  6. 22 0
      paddlex/inference/models/base/cv_predictor.py
  7. 6 14
      paddlex/inference/models/general_recognition.py
  8. 6 14
      paddlex/inference/models/image_classification.py
  9. 9 15
      paddlex/inference/models/image_unwarping.py
  10. 10 12
      paddlex/inference/models/instance_segmentation.py
  11. 12 13
      paddlex/inference/models/object_detection.py
  12. 6 14
      paddlex/inference/models/semantic_segmentation.py
  13. 8 10
      paddlex/inference/models/table_recognition.py
  14. 8 10
      paddlex/inference/models/text_detection.py
  15. 8 10
      paddlex/inference/models/text_recognition.py
  16. 0 3
      paddlex/inference/models/ts_ad.py
  17. 0 3
      paddlex/inference/models/ts_cls.py
  18. 0 3
      paddlex/inference/models/ts_fc.py
  19. 30 0
      paddlex/inference/models/utils/predict_set.py
  20. 2 1
      paddlex/inference/pipelines/__init__.py
  21. 12 9
      paddlex/inference/pipelines/ocr.py
  22. 5 5
      paddlex/inference/pipelines/single_model_pipeline.py
  23. 3 1
      paddlex/inference/utils/pp_option.py
  24. 11 7
      paddlex/modules/predictor.py
  25. 3 3
      paddlex/paddlex_cli.py
  26. 0 2
      paddlex/pipelines/OCR.yaml
  27. 5 2
      paddlex/pipelines/image_classification.yaml
  28. 5 2
      paddlex/pipelines/instance_segmentation.yaml
  29. 5 3
      paddlex/pipelines/object_detection.yaml
  30. 5 2
      paddlex/pipelines/semantic_segmentation.yaml

+ 53 - 32
paddlex/inference/components/paddle_predictor/predictor.py

@@ -17,8 +17,9 @@ from abc import abstractmethod
 import lazy_paddle as paddle
 import numpy as np
 
-from ..base import BaseComponent
 from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ..base import BaseComponent
 
 
 class BasePaddlePredictor(BaseComponent):
@@ -28,17 +29,27 @@ class BasePaddlePredictor(BaseComponent):
     DEAULT_OUTPUTS = {"pred": "pred"}
     ENABLE_BATCH = True
 
-    def __init__(self, model_dir, model_prefix, option):
+    def __init__(self, model_dir, model_prefix, option: PaddlePredictorOption = None):
         super().__init__()
+        self.model_dir = model_dir
+        self.model_prefix = model_prefix
+        self.option = option
+        self._is_initialized = False
+
+    def _build(self):
+        if not self.option:
+            self.option = PaddlePredictorOption()
         (
             self.predictor,
             self.inference_config,
             self.input_names,
             self.input_handlers,
             self.output_handlers,
-        ) = self._create(model_dir, model_prefix, option)
+        ) = self._create()
+        self._is_initialized = True
+        logging.debug(f"Env: {self.option}")
 
-    def _create(self, model_dir, model_prefix, option):
+    def _create(self):
         """_create"""
         from lazy_paddle.inference import Config, create_predictor
 
@@ -46,17 +57,17 @@ class BasePaddlePredictor(BaseComponent):
             hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
         )
         model_postfix = ".json" if use_pir else ".pdmodel"
-        model_file = (model_dir / f"{model_prefix}{model_postfix}").as_posix()
-        params_file = (model_dir / f"{model_prefix}.pdiparams").as_posix()
+        model_file = (self.model_dir / f"{self.model_prefix}{model_postfix}").as_posix()
+        params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
         config = Config(model_file, params_file)
 
-        if option.device == "gpu":
-            config.enable_use_gpu(200, option.device_id)
+        if self.option.device == "gpu":
+            config.enable_use_gpu(200, self.option.device_id)
             if paddle.is_compiled_with_rocm():
                 os.environ["FLAGS_conv_workspace_size_limit"] = "2000"
             elif hasattr(config, "enable_new_ir"):
-                config.enable_new_ir(option.enable_new_ir)
-        elif option.device == "npu":
+                config.enable_new_ir(self.option.enable_new_ir)
+        elif self.option.device == "npu":
             config.enable_custom_device("npu")
             os.environ["FLAGS_npu_jit_compile"] = "0"
             os.environ["FLAGS_use_stride_kernel"] = "0"
@@ -66,23 +77,25 @@ class BasePaddlePredictor(BaseComponent):
             )
             os.environ["FLAGS_npu_scale_aclnn"] = "True"
             os.environ["FLAGS_npu_split_aclnn"] = "True"
-        elif option.device == "xpu":
+        elif self.option.device == "xpu":
             os.environ["BKCL_FORCE_SYNC"] = "1"
             os.environ["BKCL_TIMEOUT"] = "1800"
             os.environ["FLAGS_use_stride_kernel"] = "0"
-        elif option.device == "mlu":
+        elif self.option.device == "mlu":
             config.enable_custom_device("mlu")
             os.environ["FLAGS_use_stride_kernel"] = "0"
         else:
-            assert option.device == "cpu"
+            assert self.option.device == "cpu"
             config.disable_gpu()
-            config.enable_new_ir(option.enable_new_ir)
-            config.enable_new_executor(True)
-            if "mkldnn" in option.run_mode:
+            if hasattr(config, "enable_new_ir"):
+                config.enable_new_ir(self.option.enable_new_ir)
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor(True)
+            if "mkldnn" in self.option.run_mode:
                 try:
                     config.enable_mkldnn()
-                    config.set_cpu_math_library_num_threads(option.cpu_threads)
-                    if "bf16" in option.run_mode:
+                    config.set_cpu_math_library_num_threads(self.option.cpu_threads)
+                    if "bf16" in self.option.run_mode:
                         config.enable_mkldnn_bfloat16()
                 except Exception as e:
                     logging.warning(
@@ -94,34 +107,34 @@ class BasePaddlePredictor(BaseComponent):
             "trt_fp32": Config.Precision.Float32,
             "trt_fp16": Config.Precision.Half,
         }
-        if option.run_mode in precision_map.keys():
+        if self.option.run_mode in precision_map.keys():
             config.enable_tensorrt_engine(
-                workspace_size=(1 << 25) * option.batch_size,
-                max_batch_size=option.batch_size,
-                min_subgraph_size=option.min_subgraph_size,
-                precision_mode=precision_map[option.run_mode],
-                trt_use_static=option.trt_use_static,
-                use_calib_mode=option.trt_calib_mode,
+                workspace_size=(1 << 25) * self.option.batch_size,
+                max_batch_size=self.option.batch_size,
+                min_subgraph_size=self.option.min_subgraph_size,
+                precision_mode=precision_map[self.option.run_mode],
+                trt_use_static=self.option.trt_use_static,
+                use_calib_mode=self.option.trt_calib_mode,
             )
 
-            if option.shape_info_filename is not None:
-                if not os.path.exists(option.shape_info_filename):
-                    config.collect_shape_range_info(option.shape_info_filename)
+            if self.option.shape_info_filename is not None:
+                if not os.path.exists(self.option.shape_info_filename):
+                    config.collect_shape_range_info(self.option.shape_info_filename)
                     logging.info(
-                        f"Dynamic shape info is collected into: {option.shape_info_filename}"
+                        f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
                     )
                 else:
                     logging.info(
-                        f"A dynamic shape info file ( {option.shape_info_filename} ) already exists. \
+                        f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. \
 No need to generate again."
                     )
                 config.enable_tuned_tensorrt_dynamic_shape(
-                    option.shape_info_filename, True
+                    self.option.shape_info_filename, True
                 )
 
         # Disable paddle inference logging
         config.disable_glog_info()
-        for del_p in option.delete_pass:
+        for del_p in self.option.delete_pass:
             config.delete_pass(del_p)
         # Enable shared memory
         config.enable_memory_optim()
@@ -149,6 +162,9 @@ No need to generate again."
         return self.input_names
 
     def apply(self, **kwargs):
+        if not self._is_initialized:
+            self._build()
+
         x = self.to_batch(**kwargs)
         for idx in range(len(x)):
             self.input_handlers[idx].reshape(x[idx].shape)
@@ -164,6 +180,11 @@ No need to generate again."
     def format_output(self, pred):
         return [{"pred": res} for res in zip(*pred)]
 
+    def set_option(self, option):
+        if option != self.option:
+            self.option = option
+            self._build()
+
     @abstractmethod
     def to_batch(self):
         raise NotImplementedError

+ 3 - 0
paddlex/inference/components/transforms/image/common.py

@@ -159,6 +159,9 @@ class ReadImage(BaseComponent):
         imgs_lists = sorted(imgs_lists)
         return imgs_lists
 
+    def set_batch_size(self, batch_size):
+        self.batch_size = batch_size
+
 
 class GetImageInfo(BaseComponent):
     """Get Image Info"""

+ 1 - 3
paddlex/inference/models/__init__.py

@@ -58,7 +58,7 @@ def _create_hp_predictor(
 
 
 def create_model(
-    model: str, device: str = None, *args, use_hpip=False, hpi_params=None, **kwargs
+    model: str, *args, use_hpip=False, hpi_params=None, **kwargs
 ) -> BasePredictor:
     model_dir = check_model(model)
     config = BasePredictor.load_config(model_dir)
@@ -67,7 +67,6 @@ def create_model(
         return _create_hp_predictor(
             model_name=model_name,
             model_dir=model_dir,
-            device=device,
             config=config,
             hpi_params=hpi_params,
             *args,
@@ -77,7 +76,6 @@ def create_model(
         return BasicPredictor.get(model_name)(
             model_dir=model_dir,
             config=config,
-            device=device,
             *args,
             **kwargs,
         )

+ 16 - 0
paddlex/inference/models/base/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_predictor import BasePredictor, BasicPredictor
+from .cv_predictor import CVPredictor

+ 50 - 30
paddlex/inference/models/base.py → paddlex/inference/models/base/base_predictor.py

@@ -17,25 +17,18 @@ import codecs
 from pathlib import Path
 from abc import abstractmethod
 
-import GPUtil
-
-from ...utils.subclass_register import AutoRegisterABCMetaClass
-from ..utils.device import constr_device
-from ...utils import logging
-from ..components.base import BaseComponent, ComponentsEngine
-from ..utils.pp_option import PaddlePredictorOption
-from ..utils.process_hook import generatorable_method
-
-
-def _get_default_device():
-    avail_gpus = GPUtil.getAvailable()
-    if not avail_gpus:
-        return "cpu"
-    else:
-        return constr_device("gpu", [avail_gpus[0]])
+from ....utils.subclass_register import AutoRegisterABCMetaClass
+from ....utils.func_register import FuncRegister
+from ....utils import logging
+from ...utils.device import constr_device
+from ...components.base import BaseComponent, ComponentsEngine
+from ...utils.pp_option import PaddlePredictorOption
+from ...utils.process_hook import generatorable_method
+from ..utils.predict_set import DeviceSetMixin, PPOptionSetMixin
 
 
 class BasePredictor(BaseComponent):
+
     KEEP_INPUT = False
     YIELD_BATCH = False
 
@@ -46,17 +39,20 @@ class BasePredictor(BaseComponent):
 
     MODEL_FILE_PREFIX = "inference"
 
-    def __init__(self, model_dir, config=None, device=None, **kwargs):
+    def __init__(self, model_dir, config=None):
         super().__init__()
         self.model_dir = Path(model_dir)
         self.config = config if config else self.load_config(self.model_dir)
-        self.device = device if device else _get_default_device()
-        self.kwargs = self._check_args(kwargs)
+
+        self._pred_set_func_map = {}
+        self._pred_set_register = FuncRegister(self._pred_set_func_map)
+
         # alias predict() to the __call__()
         self.predict = self.__call__
 
-    def __call__(self, *args, **kwargs):
-        for res in super().__call__(*args, **kwargs):
+    def __call__(self, input, **kwargs):
+        self._set_predict(**kwargs)
+        for res in super().__call__(input):
             yield res["result"]
 
     @property
@@ -82,22 +78,28 @@ class BasePredictor(BaseComponent):
             dic = yaml.load(file, Loader=yaml.FullLoader)
         return dic
 
-    def _check_args(self, kwargs):
-        return kwargs
+    def _set_predict(self, **kwargs):
+        for k in kwargs:
+            self._pred_set_func_map[k](kwargs[k])
 
 
-class BasicPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
+class BasicPredictor(
+    BasePredictor, DeviceSetMixin, PPOptionSetMixin, metaclass=AutoRegisterABCMetaClass
+):
 
     __is_base = True
 
-    def __init__(self, model_dir, config=None, device=None, pp_option=None, **kwargs):
-        super().__init__(model_dir=model_dir, config=config, device=device, **kwargs)
-        self.pp_option = PaddlePredictorOption() if pp_option is None else pp_option
-        self.pp_option.set_device(self.device)
-        self.components = self._build_components()
+    def __init__(self, model_dir, config=None):
+        super().__init__(model_dir=model_dir, config=config)
+        self._pred_set_register("device")(self.set_device)
+        self._pred_set_register("pp_option")(self.set_pp_option)
+
+        self.pp_option = PaddlePredictorOption()
+        self.components = {}
+        self._build_components()
         self.engine = ComponentsEngine(self.components)
         logging.debug(
-            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}\nEnv: {self.pp_option}"
+            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
         )
 
     def apply(self, x):
@@ -108,6 +110,24 @@ class BasicPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
     def _generate_res(self, batch_data):
         return [{"result": self._pack_res(data)} for data in batch_data]
 
+    def _add_component(self, cmps):
+        if not isinstance(cmps, list):
+            cmps = [cmps]
+
+        for cmp in cmps:
+            if not isinstance(cmp, (list, tuple)):
+                key = cmp.__class__.__name__
+            else:
+                assert len(cmp) == 2
+                key = cmp[0]
+                cmp = cmp[1]
+            assert isinstance(key, str)
+            assert isinstance(cmp, BaseComponent)
+            assert (
+                key not in self.components
+            ), f"The key ({key}) has been used: {self.components}!"
+            self.components[key] = cmp
+
     @abstractmethod
     def _build_components(self):
         raise NotImplementedError

+ 22 - 0
paddlex/inference/models/base/cv_predictor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils.predict_set import BatchSetMixin
+from .base_predictor import BasicPredictor
+
+
+class CVPredictor(BasicPredictor, BatchSetMixin):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._pred_set_register("batch_size")(self.set_batch_size)

+ 6 - 14
paddlex/inference/models/general_recognition.py

@@ -19,46 +19,38 @@ from ...modules.general_recognition.model_list import MODELS
 from ..components import *
 from ..results import BaseResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class ShiTuRecPredictor(BasicPredictor):
+class ShiTuRecPredictor(CVPredictor):
 
     entities = MODELS
 
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def _check_args(self, kwargs):
-        assert set(kwargs.keys()).issubset(set(["batch_size"]))
-        return kwargs
-
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
-        )
+        self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
-            ops[tf_key] = op
+            self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
+        self._add_component(("Predictor", predictor))
 
         post_processes = self.config["PostProcess"]
         for key in post_processes:
             func = self._FUNC_MAP.get(key)
             args = post_processes.get(key, {})
             op = func(self, **args) if args else func(self)
-            ops[key] = op
-        return ops
+            self._add_component(op)
 
     @register("ResizeImage")
     # TODO(gaotingquan): backend & interpolation

+ 6 - 14
paddlex/inference/models/image_classification.py

@@ -20,46 +20,38 @@ from ...modules.multilabel_classification.model_list import MODELS as ML_MODELS
 from ..components import *
 from ..results import TopkResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class ClasPredictor(BasicPredictor):
+class ClasPredictor(CVPredictor):
 
     entities = [*MODELS, *ML_MODELS]
 
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def _check_args(self, kwargs):
-        assert set(kwargs.keys()).issubset(set(["batch_size"]))
-        return kwargs
-
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            format="RGB", batch_size=self.kwargs.get("batch_size", 1)
-        )
+        self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
-            ops[tf_key] = op
+            self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
+        self._add_component(("Predictor", predictor))
 
         post_processes = self.config["PostProcess"]
         for key in post_processes:
             func = self._FUNC_MAP.get(key)
             args = post_processes.get(key, {})
             op = func(self, **args) if args else func(self)
-            ops[key] = op
-        return ops
+            self._add_component(op)
 
     @register("ResizeImage")
     # TODO(gaotingquan): backend & interpolation

+ 9 - 15
paddlex/inference/models/image_unwarping.py

@@ -16,34 +16,28 @@ from ...modules.image_unwarping.model_list import MODELS
 from ..components import *
 from ..results import DocTrResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class WarpPredictor(BasicPredictor):
+class WarpPredictor(CVPredictor):
 
     entities = MODELS
 
-    def _check_args(self, kwargs):
-        assert set(kwargs.keys()).issubset(set(["batch_size"]))
-        return kwargs
-
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            format="RGB", batch_size=self.kwargs.get("batch_size", 1)
+        self._add_component(
+            [
+                ReadImage(format="RGB"),
+                Normalize(mean=0.0, std=1.0, scale=1.0 / 255),
+                ToCHWImage(),
+            ]
         )
-        ops["Normalize"] = Normalize(mean=0.0, std=1.0, scale=1.0 / 255)
-        ops["ToCHWImage"] = ToCHWImage()
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
-
-        ops["postprocess"] = DocTrPostProcess()
-        return ops
+        self._add_component([("Predictor", predictor), DocTrPostProcess()])
 
     @batchable_method
     def _pack_res(self, single):

+ 10 - 12
paddlex/inference/models/instance_segmentation.py

@@ -27,17 +27,14 @@ class InstanceSegPredictor(DetPredictor):
     entities = MODELS
 
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
-        )
+        self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["Preprocess"]:
             tf_key = cfg["type"]
             func = self._FUNC_MAP.get(tf_key)
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)
-            ops[tf_key] = op
+            self._add_component(op)
 
         predictor = ImageDetPredictor(
             model_dir=self.model_dir,
@@ -54,15 +51,16 @@ class InstanceSegPredictor(DetPredictor):
             predictor.set_inputs(
                 {"img": "img", "scale_factors": "scale_factors", "img_size": "img_size"}
             )
-
-        ops["predictor"] = predictor
-
-        ops["postprocess"] = InstanceSegPostProcess(
-            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
+        self._add_component(
+            [
+                ("Predictor", predictor),
+                InstanceSegPostProcess(
+                    threshold=self.config["draw_threshold"],
+                    labels=self.config["label_list"],
+                ),
+            ]
         )
 
-        return ops
-
     def _pack_res(self, single):
         keys = ["img_path", "boxes", "masks"]
         return InstanceSegResult({key: single[key] for key in keys})

+ 12 - 13
paddlex/inference/models/object_detection.py

@@ -19,10 +19,10 @@ from ...modules.object_detection.model_list import MODELS
 from ..components import *
 from ..results import DetResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class DetPredictor(BasicPredictor):
+class DetPredictor(CVPredictor):
 
     entities = MODELS
 
@@ -30,17 +30,14 @@ class DetPredictor(BasicPredictor):
     register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
-        )
+        self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["Preprocess"]:
             tf_key = cfg["type"]
             func = self._FUNC_MAP.get(tf_key)
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)
-            ops[tf_key] = op
+            self._add_component(op)
 
         predictor = ImageDetPredictor(
             model_dir=self.model_dir,
@@ -62,14 +59,16 @@ class DetPredictor(BasicPredictor):
                 }
             )
 
-        ops["predictor"] = predictor
-
-        ops["postprocess"] = DetPostProcess(
-            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
+        self._add_component(
+            [
+                ("Predictor", predictor),
+                DetPostProcess(
+                    threshold=self.config["draw_threshold"],
+                    labels=self.config["label_list"],
+                ),
+            ]
         )
 
-        return ops
-
     @register("Resize")
     def build_resize(self, target_size, keep_ratio=False, interp=2):
         assert target_size

+ 6 - 14
paddlex/inference/models/semantic_segmentation.py

@@ -19,41 +19,33 @@ from ...modules.semantic_segmentation.model_list import MODELS
 from ..components import *
 from ..results import SegResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class SegPredictor(BasicPredictor):
+class SegPredictor(CVPredictor):
 
     entities = MODELS
 
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def _check_args(self, kwargs):
-        assert set(kwargs.keys()).issubset(set(["batch_size"]))
-        return kwargs
-
     def _build_components(self):
-        ops = {}
-        ops["ReadImage"] = ReadImage(
-            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
-        )
-        ops["ToCHWImage"] = ToCHWImage()
+        self._add_component(ReadImage(format="RGB"))
+        self._add_component(ToCHWImage())
         for cfg in self.config["Deploy"]["transforms"]:
             tf_key = cfg["type"]
             func = self._FUNC_MAP.get(tf_key)
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)
-            ops[tf_key] = op
+            self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
-        return ops
+        self._add_component(("Predictor", predictor))
 
     @register("Resize")
     def build_resize(

+ 8 - 10
paddlex/inference/models/table_recognition.py

@@ -19,11 +19,11 @@ from ...utils.func_register import FuncRegister
 from ...modules.table_recognition.model_list import MODELS
 from ..components import *
 from ..results import TableRecResult
-from .base import BasicPredictor
 from ..utils.process_hook import batchable_method
+from .base import CVPredictor
 
 
-class TablePredictor(BasicPredictor):
+class TablePredictor(CVPredictor):
     """table recognition predictor"""
 
     entities = MODELS
@@ -32,29 +32,27 @@ class TablePredictor(BasicPredictor):
     register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
-        ops = {}
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
-                ops[tf_key] = op
+                self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
+        self._add_component(("Predictor", predictor))
 
-        key, op = self.build_postprocess(**self.config["PostProcess"])
-        ops[key] = op
-        return ops
+        op = self.build_postprocess(**self.config["PostProcess"])
+        self._add_component(op)
 
     def build_postprocess(self, **kwargs):
         if kwargs.get("name") == "TableLabelDecode":
-            return "TableLabelDecode", TableLabelDecode(
+            return TableLabelDecode(
                 merge_no_span_structure=kwargs.get("merge_no_span_structure"),
                 dict_character=kwargs.get("character_dict"),
             )
@@ -63,7 +61,7 @@ class TablePredictor(BasicPredictor):
 
     @register("DecodeImage")
     def build_readimg(self, *args, **kwargs):
-        return ReadImage(batch_size=self.kwargs.get("batch_size", 1))
+        return ReadImage(*args, **kwargs)
 
     @register("TableLabelEncode")
     def foo(self, *args, **kwargs):

+ 8 - 10
paddlex/inference/models/text_detection.py

@@ -19,10 +19,10 @@ from ...modules.text_detection.model_list import MODELS
 from ..components import *
 from ..results import TextDetResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class TextDetPredictor(BasicPredictor):
+class TextDetPredictor(CVPredictor):
 
     entities = MODELS
 
@@ -30,30 +30,28 @@ class TextDetPredictor(BasicPredictor):
     register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
-        ops = {}
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
-                ops[tf_key] = op
+                self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
+        self._add_component(("Predictor", predictor))
 
-        key, op = self.build_postprocess(**self.config["PostProcess"])
-        ops[key] = op
-        return ops
+        op = self.build_postprocess(**self.config["PostProcess"])
+        self._add_component(op)
 
     @register("DecodeImage")
     def build_readimg(self, channel_first, img_mode):
         assert channel_first == False
-        return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
+        return ReadImage(format=img_mode)
 
     @register("DetResizeForTest")
     def build_resize(self, resize_long=960):
@@ -78,7 +76,7 @@ class TextDetPredictor(BasicPredictor):
 
     def build_postprocess(self, **kwargs):
         if kwargs.get("name") == "DBPostProcess":
-            return "DBPostProcess", DBPostProcess(
+            return DBPostProcess(
                 thresh=kwargs.get("thresh", 0.3),
                 box_thresh=kwargs.get("box_thresh", 0.7),
                 max_candidates=kwargs.get("max_candidates", 1000),

+ 8 - 10
paddlex/inference/models/text_recognition.py

@@ -19,10 +19,10 @@ from ...modules.text_recognition.model_list import MODELS
 from ..components import *
 from ..results import TextRecResult
 from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import CVPredictor
 
 
-class TextRecPredictor(BasicPredictor):
+class TextRecPredictor(CVPredictor):
 
     entities = MODELS
 
@@ -30,7 +30,6 @@ class TextRecPredictor(BasicPredictor):
     register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
-        ops = {}
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             assert tf_key in self._FUNC_MAP
@@ -38,23 +37,22 @@ class TextRecPredictor(BasicPredictor):
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
-                ops[tf_key] = op
+                self._add_component(op)
 
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        ops["predictor"] = predictor
+        self._add_component(("Predictor", predictor))
 
-        key, op = self.build_postprocess(**self.config["PostProcess"])
-        ops[key] = op
-        return ops
+        op = self.build_postprocess(**self.config["PostProcess"])
+        self._add_component(op)
 
     @register("DecodeImage")
     def build_readimg(self, channel_first, img_mode):
         assert channel_first == False
-        return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
+        return ReadImage(format=img_mode)
 
     @register("RecResizeImg")
     def build_resize(self, image_shape):
@@ -62,7 +60,7 @@ class TextRecPredictor(BasicPredictor):
 
     def build_postprocess(self, **kwargs):
         if kwargs.get("name") == "CTCLabelDecode":
-            return "CTCLabelDecode", CTCLabelDecode(
+            return CTCLabelDecode(
                 character_list=kwargs.get("character_dict"),
             )
         else:

+ 0 - 3
paddlex/inference/models/ts_ad.py

@@ -25,9 +25,6 @@ class TSAdPredictor(BasicPredictor):
 
     entities = MODELS
 
-    def _check_args(self, kwargs):
-        pass
-
     def _build_components(self):
         preprocess = self._build_preprocess()
         predictor = TSPPPredictor(

+ 0 - 3
paddlex/inference/models/ts_cls.py

@@ -24,9 +24,6 @@ class TSClsPredictor(BasicPredictor):
 
     entities = MODELS
 
-    def _check_args(self, kwargs):
-        pass
-
     def _build_components(self):
         preprocess = self._build_preprocess()
         predictor = TSPPPredictor(

+ 0 - 3
paddlex/inference/models/ts_fc.py

@@ -25,9 +25,6 @@ class TSFcPredictor(BasicPredictor):
 
     entities = MODELS
 
-    def _check_args(self, kwargs):
-        pass
-
     def _build_components(self):
         preprocess = self._build_preprocess()
         predictor = TSPPPredictor(

+ 30 - 0
paddlex/inference/models/utils/predict_set.py

@@ -0,0 +1,30 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class BatchSetMixin:
+    def set_batch_size(self, batch_size):
+        self.components["ReadImage"].set_batch_size(batch_size)
+
+
+class DeviceSetMixin:
+    def set_device(self, device):
+        self.pp_option.set_device(device)
+        self.components["Predictor"].set_option(self.pp_option)
+
+
+class PPOptionSetMixin:
+    def set_pp_option(self, pp_option):
+        self.pp_option = pp_option
+        self.components["Predictor"].set_option(self.pp_option)

+ 2 - 1
paddlex/inference/pipelines/__init__.py

@@ -26,6 +26,7 @@ def create_pipeline(
     pipeline: str,
     use_hpip: bool = False,
     hpi_params: Optional[Dict[str, Any]] = None,
+    *args,
     **kwargs,
 ) -> BasePipeline:
     """build model evaluater
@@ -49,6 +50,6 @@ def create_pipeline(
     if hpi_params is not None:
         predictor_kwargs["hpi_params"] = hpi_params
     pipeline = BasePipeline.get(pipeline_name)(
-        predictor_kwargs=predictor_kwargs, **{**config["Pipeline"], **kwargs}
+        predictor_kwargs=predictor_kwargs, *args, **{**config["Pipeline"], **kwargs}
     )
     return pipeline

+ 12 - 9
paddlex/inference/pipelines/ocr.py

@@ -26,15 +26,11 @@ class OCRPipeline(BasePipeline):
         self,
         det_model,
         rec_model,
-        rec_batch_size=1,
-        device="gpu",
         predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
-        self._det_predict = self._create_model(det_model, device=device)
-        self._rec_predict = self._create_model(
-            rec_model, batch_size=rec_batch_size, device=device
-        )
+        self._det_predict = self._create_model(det_model)
+        self._rec_predict = self._create_model(rec_model)
         self.is_curve = self._det_predict.model_name in [
             "PP-OCRv4_mobile_seal_det",
             "PP-OCRv4_server_seal_det",
@@ -44,8 +40,11 @@ class OCRPipeline(BasePipeline):
             det_box_type="poly" if self.is_curve else "quad"
         )
 
-    def predict(self, x):
-        for det_res in self._det_predict(x):
+    def predict(self, input, **kwargs):
+        device = kwargs.get("device", "gpu")
+        for det_res in self._det_predict(
+            input, batch_size=kwargs.get("det_batch_size", 1), device=device
+        ):
             single_img_res = (
                 det_res if self.is_curve else next(self._sort_boxes(det_res))
             )
@@ -53,7 +52,11 @@ class OCRPipeline(BasePipeline):
             single_img_res["rec_score"] = []
             if len(single_img_res["dt_polys"]) > 0:
                 all_subs_of_img = list(self._crop_by_polys(single_img_res))
-                for rec_res in self._rec_predict(all_subs_of_img):
+                for rec_res in self._rec_predict(
+                    all_subs_of_img,
+                    batch_size=kwargs.get("rec_batch_size", 1),
+                    device=device,
+                ):
                     single_img_res["rec_text"].append(rec_res["rec_text"])
                     single_img_res["rec_score"].append(rec_res["rec_score"])
             yield OCRResult(single_img_res)

+ 5 - 5
paddlex/inference/pipelines/single_model_pipeline.py

@@ -26,12 +26,12 @@ class SingleModelPipeline(BasePipeline):
         "ts_ad",
         "ts_cls",
         "multi_label_image_classification",
-        "anomaly_detection",
+        "small_object_detection",
+        "anomaly_detection",
     ]
 
-    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+    def __init__(self, model, predictor_kwargs=None):
         super().__init__(predictor_kwargs)
-        self._predict = self._create_model(model, batch_size=batch_size, device=device)
+        self._predict = self._create_model(model)
 
-    def predict(self, x):
-        yield from self._predict(x)
+    def predict(self, input, **kwargs):
+        yield from self._predict(input, **kwargs)

+ 3 - 1
paddlex/inference/utils/pp_option.py

@@ -86,7 +86,9 @@ class PaddlePredictorOption(object):
             )
         device_id = device_ids[0] if device_ids is not None else 0
         self._cfg["device_id"] = device_id
-        logging.warning(f"The device ID has been set to {device_id}.")
+        if device_type not in ("cpu",):
+            if device_ids is None or len(device_ids) > 1:
+                logging.warning(f"The device ID has been set to {device_id}.")
 
     @register("min_subgraph_size")
     def set_min_subgraph_size(self, min_subgraph_size: int):

+ 11 - 7
paddlex/modules/predictor.py

@@ -21,18 +21,22 @@ from ..utils.config import AttrDict
 class Predictor(object):
     def __init__(self, config):
         model_name = config.Global.model
-        predict_config = deepcopy(config.Predict)
+        self.predict_config = deepcopy(config.Predict)
 
-        model_dir = predict_config.pop("model_dir", None)
+        model_dir = self.predict_config.pop("model_dir", None)
         # if model_dir is None, using official
         model = model_name if model_dir is None else model_dir
-        self.input_path = predict_config.pop("input_path")
-        pp_option = PaddlePredictorOption(**predict_config.pop("kernel_option", {}))
-        self.model = create_model(model, pp_option=pp_option, **predict_config)
+        self.input_path = self.predict_config.pop("input_path")
+        self.pp_option = PaddlePredictorOption(
+            **self.predict_config.pop("kernel_option", {})
+        )
+        self.model = create_model(model)
 
     def predict(self):
-        for res in self.model(self.input_path):
-            res.print()
+        for res in self.model(
+            input=self.input_path, pp_option=self.pp_option, **self.predict_config
+        ):
+            res.print(json_format=False)
 
 
 def build_predictor(config: AttrDict):

+ 3 - 3
paddlex/paddlex_cli.py

@@ -93,10 +93,10 @@ def install(args):
 
 def pipeline_predict(pipeline, input_path, device=None, save_dir=None):
     """pipeline predict"""
-    pipeline = create_pipeline(pipeline, device=device)
-    result = pipeline(input_path)
+    pipeline = create_pipeline(pipeline)
+    result = pipeline(input_path, device=device)
     for res in result:
-        res.print()
+        res.print(json_format=False)
         # TODO(gaotingquan): support to save all
         # if save_dir:
         #     i["result"].save()

+ 0 - 2
paddlex/pipelines/OCR.yaml

@@ -8,7 +8,5 @@ Global:
 Pipeline:
   det_model: PP-OCRv4_mobile_det
   rec_model: PP-OCRv4_mobile_rec
-  rec_batch_size: 1
-  device: "gpu"
 
 ######################################## Support ########################################

+ 5 - 2
paddlex/pipelines/image_classification.yaml

@@ -7,10 +7,13 @@ Global:
 
 Pipeline:
   model: PP-LCNet_x0_5
-  batch_size: 1
-  device: "gpu"
 
 ######################################## Support ########################################
 NOTE:
+  device: 
+    - gpu
+    - gpu:2
+    - cpu
+  batch_size: "任意正整数"
   model:
     - PP-LCNet_x0_5

+ 5 - 2
paddlex/pipelines/instance_segmentation.yaml

@@ -7,10 +7,13 @@ Global:
 
 Pipeline:
   model: Mask-RT-DETR-S
-  batch_size: 1
-  device: "gpu"
 
 ######################################## Support ########################################
 NOTE:
+  device: 
+    - gpu
+    - gpu:2
+    - cpu
+  batch_size: "任意正整数"
   model:
     - Mask-RT-DETR-S

+ 5 - 3
paddlex/pipelines/object_detection.yaml

@@ -7,11 +7,13 @@ Global:
 
 Pipeline:
   model: PicoDet-S
-  batch_size: 1
-  device: "gpu"
-  # enable_hpi: False
 
 ######################################## Support ########################################
 NOTE:
+  device: 
+    - gpu
+    - gpu:2
+    - cpu
+  batch_size: "任意正整数"
   model:
     - PicoDet-S

+ 5 - 2
paddlex/pipelines/semantic_segmentation.yaml

@@ -7,10 +7,13 @@ Global:
 
 Pipeline:
   model: PP-LiteSeg-T
-  batch_size: 1
-  device: "gpu"
 
 ######################################## Support ########################################
 NOTE:
+  device: 
+    - gpu
+    - gpu:2
+    - cpu
+  batch_size: "任意正整数"
   model:
     - PP-LiteSeg-T