|
|
@@ -17,22 +17,11 @@ from copy import deepcopy
|
|
|
from typing import Dict, List
|
|
|
|
|
|
from ...utils import logging
|
|
|
-from ...utils.device import (
|
|
|
- check_supported_device_type,
|
|
|
- get_default_device,
|
|
|
- parse_device,
|
|
|
- set_env_for_device_type,
|
|
|
-)
|
|
|
-from ...utils.flags import (
|
|
|
- DISABLE_MKLDNN_MODEL_BL,
|
|
|
- DISABLE_TRT_MODEL_BL,
|
|
|
- ENABLE_MKLDNN_BYDEFAULT,
|
|
|
- USE_PIR_TRT,
|
|
|
-)
|
|
|
+from ...utils.device import get_default_device, parse_device, set_env_for_device_type
|
|
|
+from ...utils.flags import ENABLE_MKLDNN_BYDEFAULT, USE_PIR_TRT
|
|
|
from .misc import is_mkldnn_available
|
|
|
from .mkldnn_blocklist import MKLDNN_BLOCKLIST
|
|
|
from .new_ir_blocklist import NEWIR_BLOCKLIST
|
|
|
-from .trt_blocklist import TRT_BLOCKLIST
|
|
|
from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
|
|
|
|
|
|
|
|
|
@@ -67,33 +56,13 @@ class PaddlePredictorOption(object):
|
|
|
)
|
|
|
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu", "gcu")
|
|
|
|
|
|
- def __init__(self, model_name=None, **kwargs):
|
|
|
+ def __init__(self, **kwargs):
|
|
|
super().__init__()
|
|
|
- self._is_default_run_mode = True
|
|
|
- self._model_name = model_name
|
|
|
self._cfg = {}
|
|
|
self._init_option(**kwargs)
|
|
|
- self._changed = False
|
|
|
-
|
|
|
- @property
|
|
|
- def model_name(self):
|
|
|
- return self._model_name
|
|
|
-
|
|
|
- @model_name.setter
|
|
|
- def model_name(self, model_name):
|
|
|
- self._model_name = model_name
|
|
|
-
|
|
|
- @property
|
|
|
- def changed(self):
|
|
|
- return self._changed
|
|
|
-
|
|
|
- @changed.setter
|
|
|
- def changed(self, v):
|
|
|
- assert isinstance(v, bool)
|
|
|
- self._changed = v
|
|
|
|
|
|
def copy(self):
|
|
|
- obj = type(self)(self._model_name)
|
|
|
+ obj = type(self)()
|
|
|
obj._cfg = deepcopy(self._cfg)
|
|
|
if hasattr(self, "trt_cfg_setting"):
|
|
|
obj.trt_cfg_setting = self.trt_cfg_setting
|
|
|
@@ -108,15 +77,13 @@ class PaddlePredictorOption(object):
|
|
|
f"{k} is not supported to set! The supported option is: {self._get_settable_attributes()}"
|
|
|
)
|
|
|
|
|
|
- if "run_mode" in self._cfg:
|
|
|
- self._is_default_run_mode = False
|
|
|
-
|
|
|
- for k, v in self._get_default_config().items():
|
|
|
+ def setdefault_by_model_name(self, model_name):
|
|
|
+ for k, v in self._get_default_config(model_name).items():
|
|
|
self._cfg.setdefault(k, v)
|
|
|
|
|
|
# for trt
|
|
|
if self.run_mode in ("trt_int8", "trt_fp32", "trt_fp16"):
|
|
|
- trt_cfg_setting = TRT_CFG_SETTING[self.model_name]
|
|
|
+ trt_cfg_setting = TRT_CFG_SETTING[model_name]
|
|
|
if USE_PIR_TRT:
|
|
|
trt_cfg_setting["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
|
|
|
else:
|
|
|
@@ -125,7 +92,7 @@ class PaddlePredictorOption(object):
|
|
|
)
|
|
|
self.trt_cfg_setting = trt_cfg_setting
|
|
|
|
|
|
- def _get_default_config(self):
|
|
|
+ def _get_default_config(self, model_name):
|
|
|
"""get default config"""
|
|
|
if self.device_type is None:
|
|
|
device_type, device_ids = parse_device(get_default_device())
|
|
|
@@ -134,12 +101,12 @@ class PaddlePredictorOption(object):
|
|
|
device_type, device_id = self.device_type, self.device_id
|
|
|
|
|
|
default_config = {
|
|
|
- "run_mode": get_default_run_mode(self.model_name, device_type),
|
|
|
+ "run_mode": get_default_run_mode(model_name, device_type),
|
|
|
"device_type": device_type,
|
|
|
"device_id": device_id,
|
|
|
"cpu_threads": 10,
|
|
|
"delete_pass": [],
|
|
|
- "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
|
|
|
+ "enable_new_ir": True if model_name not in NEWIR_BLOCKLIST else False,
|
|
|
"enable_cinn": False,
|
|
|
"trt_cfg_setting": {},
|
|
|
"trt_use_dynamic_shapes": True, # only for trt
|
|
|
@@ -155,13 +122,6 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
def _update(self, k, v):
|
|
|
self._cfg[k] = v
|
|
|
- self.changed = True
|
|
|
-
|
|
|
- def reset_run_mode_by_default(self, model_name=None, device_type=None):
|
|
|
- if self._is_default_run_mode:
|
|
|
- model_name = model_name or self.model_name
|
|
|
- device_type = device_type or self.device_type
|
|
|
- self._update("run_mode", get_default_run_mode(model_name, device_type))
|
|
|
|
|
|
@property
|
|
|
def run_mode(self):
|
|
|
@@ -179,32 +139,6 @@ class PaddlePredictorOption(object):
|
|
|
if run_mode.startswith("mkldnn") and not is_mkldnn_available():
|
|
|
logging.warning("MKL-DNN is not available. Using `paddle` instead.")
|
|
|
run_mode = "paddle"
|
|
|
-
|
|
|
- # TODO: Check if trt is available
|
|
|
-
|
|
|
- if self._model_name is not None:
|
|
|
- # TRT Blocklist
|
|
|
- if (
|
|
|
- not DISABLE_TRT_MODEL_BL
|
|
|
- and run_mode.startswith("trt")
|
|
|
- and self._model_name in TRT_BLOCKLIST
|
|
|
- ):
|
|
|
- logging.warning(
|
|
|
- f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
|
|
|
- )
|
|
|
- run_mode = "paddle"
|
|
|
- # MKLDNN Blocklist
|
|
|
- elif (
|
|
|
- not DISABLE_MKLDNN_MODEL_BL
|
|
|
- and run_mode.startswith("mkldnn")
|
|
|
- and self._model_name in MKLDNN_BLOCKLIST
|
|
|
- ):
|
|
|
- logging.warning(
|
|
|
- f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
|
|
|
- )
|
|
|
- run_mode = "paddle"
|
|
|
-
|
|
|
- self._is_default_run_mode = False
|
|
|
self._update("run_mode", run_mode)
|
|
|
|
|
|
@property
|
|
|
@@ -218,7 +152,6 @@ class PaddlePredictorOption(object):
|
|
|
raise ValueError(
|
|
|
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
|
|
|
)
|
|
|
- check_supported_device_type(device_type, self.model_name)
|
|
|
self._update("device_type", device_type)
|
|
|
set_env_for_device_type(device_type)
|
|
|
# XXX(gaotingquan): set flag to accelerate inference in paddle 3.0b2
|
|
|
@@ -278,7 +211,7 @@ class PaddlePredictorOption(object):
|
|
|
def trt_cfg_setting(self, config: Dict):
|
|
|
"""set trt config"""
|
|
|
assert isinstance(
|
|
|
- config, dict
|
|
|
+ config, (dict, type(None))
|
|
|
), f"The trt_cfg_setting must be `dict` type, but received `{type(config)}` type!"
|
|
|
self._update("trt_cfg_setting", config)
|
|
|
|