|
|
@@ -122,12 +122,16 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
def _get_default_config(self):
|
|
|
"""get default config"""
|
|
|
- device_type, device_ids = parse_device(get_default_device())
|
|
|
+ if self.device_type is None:
|
|
|
+ device_type, device_ids = parse_device(get_default_device())
|
|
|
+ device_id = None if device_ids is None else device_ids[0]
|
|
|
+ else:
|
|
|
+ device_type, device_id = self.device_type, self.device_id
|
|
|
|
|
|
default_config = {
|
|
|
"run_mode": get_default_run_mode(self.model_name, device_type),
|
|
|
"device_type": device_type,
|
|
|
- "device_id": None if device_ids is None else device_ids[0],
|
|
|
+ "device_id": device_id,
|
|
|
"cpu_threads": 8,
|
|
|
"delete_pass": [],
|
|
|
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
|
|
|
@@ -148,9 +152,14 @@ class PaddlePredictorOption(object):
|
|
|
self._cfg[k] = v
|
|
|
self.changed = True
|
|
|
|
|
|
+ def reset_run_mode_by_default(self, model_name=None, device_type=None):
|
|
|
+ model_name = model_name or self.model_name
|
|
|
+ device_type = device_type or self.device_type
|
|
|
+ self._update("run_mode", get_default_run_mode(model_name, device_type))
|
|
|
+
|
|
|
@property
|
|
|
def run_mode(self):
|
|
|
- return self._cfg["run_mode"]
|
|
|
+ return self._cfg.get("run_mode")
|
|
|
|
|
|
@run_mode.setter
|
|
|
def run_mode(self, run_mode: str):
|
|
|
@@ -193,7 +202,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def device_type(self):
|
|
|
- return self._cfg["device_type"]
|
|
|
+ return self._cfg.get("device_type")
|
|
|
|
|
|
@device_type.setter
|
|
|
def device_type(self, device_type):
|
|
|
@@ -211,7 +220,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def device_id(self):
|
|
|
- return self._cfg["device_id"]
|
|
|
+ return self._cfg.get("device_id")
|
|
|
|
|
|
@device_id.setter
|
|
|
def device_id(self, device_id):
|
|
|
@@ -219,7 +228,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def cpu_threads(self):
|
|
|
- return self._cfg["cpu_threads"]
|
|
|
+ return self._cfg.get("cpu_threads")
|
|
|
|
|
|
@cpu_threads.setter
|
|
|
def cpu_threads(self, cpu_threads):
|
|
|
@@ -230,7 +239,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def delete_pass(self):
|
|
|
- return self._cfg["delete_pass"]
|
|
|
+ return self._cfg.get("delete_pass")
|
|
|
|
|
|
@delete_pass.setter
|
|
|
def delete_pass(self, delete_pass):
|
|
|
@@ -238,7 +247,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def enable_new_ir(self):
|
|
|
- return self._cfg["enable_new_ir"]
|
|
|
+ return self._cfg.get("enable_new_ir")
|
|
|
|
|
|
@enable_new_ir.setter
|
|
|
def enable_new_ir(self, enable_new_ir: bool):
|
|
|
@@ -247,7 +256,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def enable_cinn(self):
|
|
|
- return self._cfg["enable_cinn"]
|
|
|
+ return self._cfg.get("enable_cinn")
|
|
|
|
|
|
@enable_cinn.setter
|
|
|
def enable_cinn(self, enable_cinn: bool):
|
|
|
@@ -256,7 +265,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_cfg_setting(self):
|
|
|
- return self._cfg["trt_cfg_setting"]
|
|
|
+ return self._cfg.get("trt_cfg_setting")
|
|
|
|
|
|
@trt_cfg_setting.setter
|
|
|
def trt_cfg_setting(self, config: Dict):
|
|
|
@@ -268,7 +277,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_use_dynamic_shapes(self):
|
|
|
- return self._cfg["trt_use_dynamic_shapes"]
|
|
|
+ return self._cfg.get("trt_use_dynamic_shapes")
|
|
|
|
|
|
@trt_use_dynamic_shapes.setter
|
|
|
def trt_use_dynamic_shapes(self, trt_use_dynamic_shapes):
|
|
|
@@ -276,7 +285,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_collect_shape_range_info(self):
|
|
|
- return self._cfg["trt_collect_shape_range_info"]
|
|
|
+ return self._cfg.get("trt_collect_shape_range_info")
|
|
|
|
|
|
@trt_collect_shape_range_info.setter
|
|
|
def trt_collect_shape_range_info(self, trt_collect_shape_range_info):
|
|
|
@@ -284,7 +293,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_discard_cached_shape_range_info(self):
|
|
|
- return self._cfg["trt_discard_cached_shape_range_info"]
|
|
|
+ return self._cfg.get("trt_discard_cached_shape_range_info")
|
|
|
|
|
|
@trt_discard_cached_shape_range_info.setter
|
|
|
def trt_discard_cached_shape_range_info(self, trt_discard_cached_shape_range_info):
|
|
|
@@ -294,7 +303,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_dynamic_shapes(self):
|
|
|
- return self._cfg["trt_dynamic_shapes"]
|
|
|
+ return self._cfg.get("trt_dynamic_shapes")
|
|
|
|
|
|
@trt_dynamic_shapes.setter
|
|
|
def trt_dynamic_shapes(self, trt_dynamic_shapes: Dict[str, List[List[int]]]):
|
|
|
@@ -305,7 +314,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_dynamic_shape_input_data(self):
|
|
|
- return self._cfg["trt_dynamic_shape_input_data"]
|
|
|
+ return self._cfg.get("trt_dynamic_shape_input_data")
|
|
|
|
|
|
@trt_dynamic_shape_input_data.setter
|
|
|
def trt_dynamic_shape_input_data(
|
|
|
@@ -315,7 +324,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_shape_range_info_path(self):
|
|
|
- return self._cfg["trt_shape_range_info_path"]
|
|
|
+ return self._cfg.get("trt_shape_range_info_path")
|
|
|
|
|
|
@trt_shape_range_info_path.setter
|
|
|
def trt_shape_range_info_path(self, trt_shape_range_info_path: str):
|
|
|
@@ -324,7 +333,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def trt_allow_rebuild_at_runtime(self):
|
|
|
- return self._cfg["trt_allow_rebuild_at_runtime"]
|
|
|
+ return self._cfg.get("trt_allow_rebuild_at_runtime")
|
|
|
|
|
|
@trt_allow_rebuild_at_runtime.setter
|
|
|
def trt_allow_rebuild_at_runtime(self, trt_allow_rebuild_at_runtime):
|
|
|
@@ -332,7 +341,7 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
@property
|
|
|
def mkldnn_cache_capacity(self):
|
|
|
- return self._cfg["mkldnn_cache_capacity"]
|
|
|
+ return self._cfg.get("mkldnn_cache_capacity")
|
|
|
|
|
|
@mkldnn_cache_capacity.setter
|
|
|
def mkldnn_cache_capacity(self, capacity: int):
|