
refactor: remove model_name from PaddlePredictorOption

gaotingquan 4 months ago
commit 0794d42a91

+ 5 - 6
paddlex/inference/models/base/predictor/base_predictor.py

@@ -237,7 +237,9 @@ class BasePredictor(
 
     def create_static_infer(self):
         if not self._use_hpip:
-            return PaddleInfer(self.model_dir, self.MODEL_FILE_PREFIX, self._pp_option)
+            return PaddleInfer(
+                self.model_name, self.model_dir, self.MODEL_FILE_PREFIX, self._pp_option
+            )
         else:
             return HPInfer(
                 self.model_dir,
@@ -334,14 +336,11 @@ class BasePredictor(
         else:
             device_info = None
         if pp_option is None:
-            pp_option = PaddlePredictorOption(model_name=self.model_name)
-        elif pp_option.model_name is None:
-            pp_option.model_name = self.model_name
-            pp_option.reset_run_mode_by_default(model_name=self.model_name)
+            pp_option = PaddlePredictorOption()
         if device_info:
             pp_option.device_type = device_info[0]
             pp_option.device_id = device_info[1]
-            pp_option.reset_run_mode_by_default(device_type=device_info[0])
+        pp_option.setdefault_by_model_name(model_name=self.model_name)
         hpi_info = self.get_hpi_info()
         if hpi_info is not None:
             hpi_info = hpi_info.model_dump(exclude_unset=True)

+ 63 - 18
paddlex/inference/models/common/static_infer.py

@@ -22,7 +22,13 @@ import numpy as np
 
 from ....utils import logging
 from ....utils.deps import class_requires_deps
-from ....utils.flags import DEBUG, USE_PIR_TRT
+from ....utils.device import check_supported_device_type
+from ....utils.flags import (
+    DEBUG,
+    DISABLE_MKLDNN_MODEL_BL,
+    DISABLE_TRT_MODEL_BL,
+    USE_PIR_TRT,
+)
 from ...utils.benchmark import benchmark, set_inference_operations
 from ...utils.hpi import (
     HPIConfig,
@@ -32,8 +38,10 @@ from ...utils.hpi import (
     TensorRTConfig,
     suggest_inference_backend_and_config,
 )
+from ...utils.mkldnn_blocklist import MKLDNN_BLOCKLIST
 from ...utils.model_paths import get_model_paths
 from ...utils.pp_option import PaddlePredictorOption, get_default_run_mode
+from ...utils.trt_blocklist import TRT_BLOCKLIST
 from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG
 
 CACHE_DIR = ".cache"
@@ -263,11 +271,13 @@ class StaticInfer(metaclass=abc.ABCMeta):
 class PaddleInfer(StaticInfer):
     def __init__(
         self,
+        model_name: str,
         model_dir: Union[str, PathLike],
         model_file_prefix: str,
         option: PaddlePredictorOption,
     ) -> None:
         super().__init__()
+        self._model_name = model_name
         self.model_dir = Path(model_dir)
         self.model_file_prefix = model_file_prefix
         self._option = option
@@ -287,22 +297,34 @@
         pred = self.infer(x)
         return pred
 
-    def _create(
-        self,
-    ):
-        """_create"""
-        import paddle
-        import paddle.inference
-
-        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
-        if "paddle" not in model_paths:
-            raise RuntimeError("No valid PaddlePaddle model found")
-        model_file, params_file = model_paths["paddle"]
-
+    def _check_run_mode(self):
+        # TODO: Check if trt is available
+        # check available for trt
         if (
-            self._option.model_name == "LaTeX_OCR_rec"
+            not DISABLE_TRT_MODEL_BL
+            and self._option.run_mode.startswith("trt")
+            and self._model_name in TRT_BLOCKLIST
+            and self._option.device_type == "gpu"
+        ):
+            logging.warning(
+                f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
+            )
+            self._option.run_mode = "paddle"
+
+        # check available for mkldnn
+        elif (
+            not DISABLE_MKLDNN_MODEL_BL
+            and self._option.run_mode.startswith("mkldnn")
+            and self._model_name in MKLDNN_BLOCKLIST
             and self._option.device_type == "cpu"
         ):
+            logging.warning(
+                f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
+            )
+            self._option.run_mode = "paddle"
+
+        # check available for model
+        if self._model_name == "LaTeX_OCR_rec" and self._option.device_type == "cpu":
             import cpuinfo
 
             if (
@@ -313,7 +336,22 @@ class PaddleInfer(StaticInfer):
                     "Now, the `LaTeX_OCR_rec` model only support `mkldnn` mode when running on Intel CPU devices. So using `mkldnn` instead."
                 )
             self._option.run_mode = "mkldnn"
-            logging.debug("`run_mode` updated to 'mkldnn'")
+
+    def _create(
+        self,
+    ):
+        """_create"""
+        import paddle
+        import paddle.inference
+
+        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
+        if "paddle" not in model_paths:
+            raise RuntimeError("No valid PaddlePaddle model found")
+
+        check_supported_device_type(self._option.device_type, self._model_name)
+        self._check_run_mode()
+
+        model_file, params_file = model_paths["paddle"]
 
         if self._option.device_type == "cpu" and self._option.device_id is not None:
             self._option.device_id = None
@@ -328,7 +366,10 @@ class PaddleInfer(StaticInfer):
 
         # for TRT
         if self._option.run_mode.startswith("trt"):
-            assert self._option.device_type == "gpu"
+            assert self._option.device_type.lower() == "gpu", (
+                f"`{self._option.run_mode}` is only available on GPU devices, "
+                f"but got device_type='{self._option.device_type}'."
+            )
             cache_dir = self.model_dir / CACHE_DIR / "paddle"
             config = self._configure_trt(
                 model_file,
@@ -681,7 +722,12 @@
                     trt_dynamic_shape_input_data,
                 )
                 kwargs["trt_dynamic_shape_input_data"] = trt_dynamic_shape_input_data
-        pp_option = PaddlePredictorOption(self._config.pdx_model_name, **kwargs)
+        pp_option = PaddlePredictorOption(**kwargs)
         logging.info("Using Paddle Inference backend")
         logging.info("Paddle predictor option: %s", pp_option)
-        return PaddleInfer(self._model_dir, self._model_file_prefix, option=pp_option)
+        return PaddleInfer(
+            self._config.pdx_model_name,
+            self._model_dir,
+            self._model_file_prefix,
+            option=pp_option,
+        )
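Since the model name now lives on PaddleInfer itself, the TRT/MKL-DNN blocklist checks and the device-type check run at engine-creation time rather than inside the option object. A rough sketch of constructing the backend after this change (the directory, prefix, and model name are placeholders):

    infer = PaddleInfer(
        "PP-OCRv4_mobile_det",   # model_name is now the first positional argument
        "/path/to/model_dir",
        "inference",             # model file prefix
        pp_option,
    )
    # _create() calls check_supported_device_type() and _check_run_mode()
    # before building the underlying Paddle Inference config.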

+ 0 - 2
paddlex/inference/pipelines/base.py

@@ -99,8 +99,6 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         # TODO(gaotingquan): support to specify pp_option by model in pipeline
         if self.pp_option is not None:
             pp_option = self.pp_option.copy()
-            pp_option.model_name = config["model_name"]
-            pp_option.run_mode = self.pp_option.run_mode
         else:
             pp_option = None
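With the per-model fields gone, a pipeline can hand a copy of one shared option object to each of its models; every predictor later applies its own defaults via setdefault_by_model_name(). A hedged sketch (cpu_threads is just an example of a model-independent setting):

    shared = PaddlePredictorOption(cpu_threads=8)
    det_option = shared.copy()   # nothing model-specific left to patch here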
 

+ 12 - 79
paddlex/inference/utils/pp_option.py

@@ -17,22 +17,11 @@ from copy import deepcopy
 from typing import Dict, List
 
 from ...utils import logging
-from ...utils.device import (
-    check_supported_device_type,
-    get_default_device,
-    parse_device,
-    set_env_for_device_type,
-)
-from ...utils.flags import (
-    DISABLE_MKLDNN_MODEL_BL,
-    DISABLE_TRT_MODEL_BL,
-    ENABLE_MKLDNN_BYDEFAULT,
-    USE_PIR_TRT,
-)
+from ...utils.device import get_default_device, parse_device, set_env_for_device_type
+from ...utils.flags import ENABLE_MKLDNN_BYDEFAULT, USE_PIR_TRT
 from .misc import is_mkldnn_available
 from .mkldnn_blocklist import MKLDNN_BLOCKLIST
 from .new_ir_blocklist import NEWIR_BLOCKLIST
-from .trt_blocklist import TRT_BLOCKLIST
 from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
 
 
@@ -67,33 +56,13 @@ class PaddlePredictorOption(object):
     )
     SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu")
 
-    def __init__(self, model_name=None, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__()
-        self._is_default_run_mode = True
-        self._model_name = model_name
         self._cfg = {}
         self._init_option(**kwargs)
-        self._changed = False
-
-    @property
-    def model_name(self):
-        return self._model_name
-
-    @model_name.setter
-    def model_name(self, model_name):
-        self._model_name = model_name
-
-    @property
-    def changed(self):
-        return self._changed
-
-    @changed.setter
-    def changed(self, v):
-        assert isinstance(v, bool)
-        self._changed = v
 
     def copy(self):
-        obj = type(self)(self._model_name)
+        obj = type(self)()
         obj._cfg = deepcopy(self._cfg)
         if hasattr(self, "trt_cfg_setting"):
             obj.trt_cfg_setting = self.trt_cfg_setting
@@ -108,15 +77,13 @@ class PaddlePredictorOption(object):
                     f"{k} is not supported to set! The supported option is: {self._get_settable_attributes()}"
                 )
 
-        if "run_mode" in self._cfg:
-            self._is_default_run_mode = False
-
-        for k, v in self._get_default_config().items():
+    def setdefault_by_model_name(self, model_name):
+        for k, v in self._get_default_config(model_name).items():
             self._cfg.setdefault(k, v)
 
         # for trt
         if self.run_mode in ("trt_int8", "trt_fp32", "trt_fp16"):
-            trt_cfg_setting = TRT_CFG_SETTING[self.model_name]
+            trt_cfg_setting = TRT_CFG_SETTING[model_name]
             if USE_PIR_TRT:
                 trt_cfg_setting["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
             else:
@@ -125,7 +92,7 @@ class PaddlePredictorOption(object):
                 )
             self.trt_cfg_setting = trt_cfg_setting
 
-    def _get_default_config(self):
+    def _get_default_config(self, model_name):
         """get default config"""
         if self.device_type is None:
             device_type, device_ids = parse_device(get_default_device())
@@ -134,12 +101,12 @@ class PaddlePredictorOption(object):
             device_type, device_id = self.device_type, self.device_id
 
         default_config = {
-            "run_mode": get_default_run_mode(self.model_name, device_type),
+            "run_mode": get_default_run_mode(model_name, device_type),
             "device_type": device_type,
             "device_id": device_id,
             "cpu_threads": 10,
             "delete_pass": [],
-            "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
+            "enable_new_ir": True if model_name not in NEWIR_BLOCKLIST else False,
             "enable_cinn": False,
             "trt_cfg_setting": {},
             "trt_use_dynamic_shapes": True,  # only for trt
@@ -155,13 +122,6 @@ class PaddlePredictorOption(object):
 
     def _update(self, k, v):
         self._cfg[k] = v
-        self.changed = True
-
-    def reset_run_mode_by_default(self, model_name=None, device_type=None):
-        if self._is_default_run_mode:
-            model_name = model_name or self.model_name
-            device_type = device_type or self.device_type
-            self._update("run_mode", get_default_run_mode(model_name, device_type))
 
     @property
     def run_mode(self):
@@ -179,32 +139,6 @@ class PaddlePredictorOption(object):
         if run_mode.startswith("mkldnn") and not is_mkldnn_available():
             logging.warning("MKL-DNN is not available. Using `paddle` instead.")
             run_mode = "paddle"
-
-        # TODO: Check if trt is available
-
-        if self._model_name is not None:
-            # TRT Blocklist
-            if (
-                not DISABLE_TRT_MODEL_BL
-                and run_mode.startswith("trt")
-                and self._model_name in TRT_BLOCKLIST
-            ):
-                logging.warning(
-                    f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
-                )
-                run_mode = "paddle"
-            # MKLDNN Blocklist
-            elif (
-                not DISABLE_MKLDNN_MODEL_BL
-                and run_mode.startswith("mkldnn")
-                and self._model_name in MKLDNN_BLOCKLIST
-            ):
-                logging.warning(
-                    f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
-                )
-                run_mode = "paddle"
-
-        self._is_default_run_mode = False
         self._update("run_mode", run_mode)
 
     @property
@@ -218,7 +152,6 @@ class PaddlePredictorOption(object):
             raise ValueError(
                 f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
             )
-        check_supported_device_type(device_type, self.model_name)
         self._update("device_type", device_type)
         set_env_for_device_type(device_type)
         # XXX(gaotingquan): set flag to accelerate inference in paddle 3.0b2
@@ -278,7 +211,7 @@ class PaddlePredictorOption(object):
     def trt_cfg_setting(self, config: Dict):
         """set trt config"""
         assert isinstance(
-            config, dict
+            config, (dict, type(None))
         ), f"The trt_cfg_setting must be `dict` type, but received `{type(config)}` type!"
         self._update("trt_cfg_setting", config)
 

+ 0 - 1
paddlex/model.py

@@ -90,7 +90,6 @@ class _ModelBasedConfig(_BaseModel):
 
         create_predictor_kwargs = {}
         if kernel_option is not UNSET:
-            kernel_option.setdefault("model_name", self._model_name)
             create_predictor_kwargs["pp_option"] = PaddlePredictorOption(
                 **kernel_option
             )
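As a result, a kernel_option mapping read from a model config now maps directly onto the constructor, with no model_name key injected first. For instance (keys illustrative, limited to options that pp_option.py exposes as settable):

    kernel_option = {"run_mode": "trt_fp16", "cpu_threads": 10}
    pp_option = PaddlePredictorOption(**kernel_option)   # model name is applied later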