|
|
@@ -23,13 +23,34 @@ from ...utils.device import (
|
|
|
parse_device,
|
|
|
set_env_for_device_type,
|
|
|
)
|
|
|
-from ...utils.flags import USE_PIR_TRT
|
|
|
+from ...utils.flags import (
|
|
|
+ DISABLE_MKLDNN_MODEL_BL,
|
|
|
+ DISABLE_TRT_MODEL_BL,
|
|
|
+ ENABLE_MKLDNN_BYDEFAULT,
|
|
|
+ USE_PIR_TRT,
|
|
|
+)
|
|
|
+from .misc import is_mkldnn_available
|
|
|
from .mkldnn_blocklist import MKLDNN_BLOCKLIST
|
|
|
from .new_ir_blocklist import NEWIR_BLOCKLIST
|
|
|
from .trt_blocklist import TRT_BLOCKLIST
|
|
|
from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
|
|
|
|
|
|
|
|
|
+def get_default_run_mode(model_name, device_type):
|
|
|
+ if not model_name:
|
|
|
+ return "paddle"
|
|
|
+ if device_type != "cpu":
|
|
|
+ return "paddle"
|
|
|
+ if (
|
|
|
+ ENABLE_MKLDNN_BYDEFAULT
|
|
|
+ and is_mkldnn_available()
|
|
|
+ and model_name not in MKLDNN_BLOCKLIST
|
|
|
+ ):
|
|
|
+ return "mkldnn"
|
|
|
+ else:
|
|
|
+ return "paddle"
|
|
|
+
|
|
|
+
|
|
|
class PaddlePredictorOption(object):
|
|
|
"""Paddle Inference Engine Option"""
|
|
|
|
|
|
@@ -104,7 +125,7 @@ class PaddlePredictorOption(object):
|
|
|
device_type, device_ids = parse_device(get_default_device())
|
|
|
|
|
|
default_config = {
|
|
|
- "run_mode": "paddle",
|
|
|
+ "run_mode": get_default_run_mode(self.model_name, device_type),
|
|
|
"device_type": device_type,
|
|
|
"device_id": None if device_ids is None else device_ids[0],
|
|
|
"cpu_threads": 8,
|
|
|
@@ -142,13 +163,21 @@ class PaddlePredictorOption(object):
|
|
|
|
|
|
if self._model_name is not None:
|
|
|
# TRT Blocklist
|
|
|
- if run_mode.startswith("trt") and self._model_name in TRT_BLOCKLIST:
|
|
|
+ if (
|
|
|
+ not DISABLE_TRT_MODEL_BL
|
|
|
+ and run_mode.startswith("trt")
|
|
|
+ and self._model_name in TRT_BLOCKLIST
|
|
|
+ ):
|
|
|
logging.warning(
|
|
|
f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
|
|
|
)
|
|
|
run_mode = "paddle"
|
|
|
# MKLDNN Blocklist
|
|
|
- elif run_mode.startswith("mkldnn") and self._model_name in MKLDNN_BLOCKLIST:
|
|
|
+ elif (
|
|
|
+ not DISABLE_MKLDNN_MODEL_BL
|
|
|
+ and run_mode.startswith("mkldnn")
|
|
|
+ and self._model_name in MKLDNN_BLOCKLIST
|
|
|
+ ):
|
|
|
logging.warning(
|
|
|
f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
|
|
|
)
|