|
|
@@ -24,6 +24,7 @@ from ...utils.device import (
|
|
|
)
|
|
|
from .new_ir_blacklist import NEWIR_BLOCKLIST
|
|
|
from .trt_blacklist import TRT_BLOCKLIST
|
|
|
+from .trt_config import TRT_PRECISION_MAP, TRT_CFG
|
|
|
|
|
|
|
|
|
class PaddlePredictorOption(object):
|
|
|
@@ -69,21 +70,24 @@ class PaddlePredictorOption(object):
|
|
|
for k, v in self._get_default_config().items():
|
|
|
self._cfg.setdefault(k, v)
|
|
|
|
|
|
+ # for trt: derive the TensorRT config from the per-model defaults
|
|
|
+ if self.run_mode in TRT_PRECISION_MAP:
|
|
|
+     # Copy the per-model entry before editing it: TRT_CFG is module-level
|
|
|
+     # state shared by every option object, so writing "precision_mode"
|
|
|
+     # into it directly would leak one instance's run mode into all others.
|
|
|
+     trt_cfg = dict(TRT_CFG[self.model_name])
|
|
|
+     trt_cfg["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
|
|
|
+     self.trt_cfg = trt_cfg
|
|
|
+
|
|
|
def _get_default_config(self):
|
|
|
"""get default config"""
|
|
|
device_type, device_ids = parse_device(get_default_device())
|
|
|
- return {
|
|
|
+
|
|
|
+ default_config = {
|
|
|
"run_mode": "paddle",
|
|
|
"device_type": device_type,
|
|
|
"device_id": None if device_ids is None else device_ids[0],
|
|
|
"cpu_threads": 8,
|
|
|
"delete_pass": [],
|
|
|
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
|
|
|
- "trt_max_workspace_size": 1 << 30, # only for trt
|
|
|
- "trt_max_batch_size": 32, # only for trt
|
|
|
- "trt_min_subgraph_size": 3, # only for trt
|
|
|
- "trt_use_static": True, # only for trt
|
|
|
- "trt_use_calib_mode": False, # only for trt
|
|
|
+ "trt_cfg": {},
|
|
|
"trt_use_dynamic_shapes": True, # only for trt
|
|
|
"trt_collect_shape_range_info": True, # only for trt
|
|
|
"trt_discard_cached_shape_range_info": False, # only for trt
|
|
|
@@ -92,6 +96,7 @@ class PaddlePredictorOption(object):
|
|
|
"trt_shape_range_info_path": None, # only for trt
|
|
|
"trt_allow_rebuild_at_runtime": True, # only for trt
|
|
|
}
|
|
|
+ return default_config
|
|
|
|
|
|
def _update(self, k, v):
|
|
|
self._cfg[k] = v
|
|
|
@@ -173,49 +178,16 @@ class PaddlePredictorOption(object):
|
|
|
self._update("enable_new_ir", enable_new_ir)
|
|
|
|
|
|
@property
|
|
|
- def trt_max_workspace_size(self):
|
|
|
- return self._cfg["trt_max_workspace_size"]
|
|
|
-
|
|
|
- @trt_max_workspace_size.setter
|
|
|
- def trt_max_workspace_size(self, trt_max_workspace_size):
|
|
|
- self._update("trt_max_workspace_size", trt_max_workspace_size)
|
|
|
-
|
|
|
- @property
|
|
|
- def trt_max_batch_size(self):
|
|
|
- return self._cfg["trt_max_batch_size"]
|
|
|
-
|
|
|
- @trt_max_batch_size.setter
|
|
|
- def trt_max_batch_size(self, trt_max_batch_size):
|
|
|
- self._update("trt_max_batch_size", trt_max_batch_size)
|
|
|
-
|
|
|
- @property
|
|
|
- def trt_min_subgraph_size(self):
|
|
|
- return self._cfg["trt_min_subgraph_size"]
|
|
|
-
|
|
|
- @trt_min_subgraph_size.setter
|
|
|
- def trt_min_subgraph_size(self, trt_min_subgraph_size: int):
|
|
|
- """set min subgraph size"""
|
|
|
- if not isinstance(trt_min_subgraph_size, int):
|
|
|
- raise Exception()
|
|
|
- self._update("trt_min_subgraph_size", trt_min_subgraph_size)
|
|
|
-
|
|
|
- @property
|
|
|
- def trt_use_static(self):
|
|
|
- return self._cfg["trt_use_static"]
|
|
|
+ def trt_cfg(self):
|
|
|
+ """get trt config"""
|
|
|
+ return self._cfg["trt_cfg"]
|
|
|
|
|
|
- @trt_use_static.setter
|
|
|
- def trt_use_static(self, trt_use_static):
|
|
|
- """set trt use static"""
|
|
|
- self._update("trt_use_static", trt_use_static)
|
|
|
-
|
|
|
- @property
|
|
|
- def trt_use_calib_mode(self):
|
|
|
- return self._cfg["trt_use_calib_mode"]
|
|
|
-
|
|
|
- @trt_use_calib_mode.setter
|
|
|
- def trt_use_calib_mode(self, trt_use_calib_mode):
|
|
|
- """set trt calib mode"""
|
|
|
- self._update("trt_use_calib_mode", trt_use_calib_mode)
|
|
|
+ @trt_cfg.setter
|
|
|
+ def trt_cfg(self, config: Dict):
|
|
|
+     """set trt config
|
|
|
+
|
|
|
+     Stores `config` under the "trt_cfg" key of the option config.
|
|
|
+     Raises TypeError when `config` is not a dict.
|
|
|
+     """
|
|
|
+     # Raise explicitly instead of asserting: assert statements are removed
|
|
|
+     # under `python -O`, which would let a non-dict value through silently.
|
|
|
+     if not isinstance(config, dict):
|
|
|
+         raise TypeError(
|
|
|
+             f"The trt_cfg must be `dict` type, but received `{type(config)}` type!"
|
|
|
+         )
|
|
|
+     self._update("trt_cfg", config)
|
|
|
|
|
|
@property
|
|
|
def trt_use_dynamic_shapes(self):
|
|
|
@@ -284,14 +256,6 @@ class PaddlePredictorOption(object):
|
|
|
# For backward compatibility
|
|
|
# TODO: Issue deprecation warnings
|
|
|
@property
|
|
|
- def min_subgraph_size(self):
|
|
|
- return self.trt_min_subgraph_size
|
|
|
-
|
|
|
- @min_subgraph_size.setter
|
|
|
- def min_subgraph_size(self, min_subgraph_size):
|
|
|
- self.trt_min_subgraph_size = min_subgraph_size
|
|
|
-
|
|
|
- @property
|
|
|
def shape_info_filename(self):
|
|
|
return self.trt_shape_range_info_path
|
|
|
|
|
|
@@ -299,22 +263,6 @@ class PaddlePredictorOption(object):
|
|
|
def shape_info_filename(self, shape_info_filename):
|
|
|
self.trt_shape_range_info_path = shape_info_filename
|
|
|
|
|
|
- @property
|
|
|
- def trt_calib_mode(self):
|
|
|
- return self.trt_use_calib_mode
|
|
|
-
|
|
|
- @trt_calib_mode.setter
|
|
|
- def trt_calib_mode(self, trt_calib_mode):
|
|
|
- self.trt_use_calib_mode = trt_calib_mode
|
|
|
-
|
|
|
- @property
|
|
|
- def batch_size(self):
|
|
|
- return self.trt_max_batch_size
|
|
|
-
|
|
|
- @batch_size.setter
|
|
|
- def batch_size(self, batch_size):
|
|
|
- self.trt_max_batch_size = batch_size
|
|
|
-
|
|
|
def set_device(self, device: str):
|
|
|
"""set device"""
|
|
|
if not device:
|