Просмотр исходного кода

enable mkldnn by default & DISABLE_TRT_MODEL_BL & DISABLE_MKLDNN_MODEL_BL (#4191)

* support DISABLE_TRT_MODEL_BL and DISABLE_MKLDNN_MODEL_BL

* enable mkldnn by default when device is set to cpu

* bugfix

* update the check about if mkldnn is supported

ref #4169
Tingquan Gao 5 месяцев назад
Родитель
Commit
559c12f0f9

+ 3 - 1
paddlex/inference/models/common/static_infer.py

@@ -411,7 +411,9 @@ class PaddleInfer(StaticInfer):
                         config.enable_mkldnn()
                         if "bf16" in self._option.run_mode:
                             config.enable_mkldnn_bfloat16()
-                        config.set_mkldnn_cache_capacity(self._option.mkldnn_cache_capacity)
+                        config.set_mkldnn_cache_capacity(
+                            self._option.mkldnn_cache_capacity
+                        )
                     else:
                         logging.warning(
                             "MKL-DNN is not available. We will disable MKL-DNN."

+ 0 - 1
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -964,7 +964,6 @@ class _LayoutParsingPipelineV2(BasePipeline):
         Returns:
             LayoutParsingResultV2: The predicted layout parsing result.
         """
-
         model_settings = self.get_model_settings(
             use_doc_orientation_classify,
             use_doc_unwarping,

+ 20 - 0
paddlex/inference/utils/misc.py

@@ -0,0 +1,20 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def is_mkldnn_available():
+    """Return True if the installed Paddle build appears to support MKL-DNN (oneDNN).
+
+    Heuristic: probes ``paddle.inference.Config`` for the
+    ``set_mkldnn_cache_capacity`` method, on the assumption that the method
+    is only exposed by builds compiled with MKL-DNN support — TODO confirm
+    this holds for all Paddle distributions.
+    """
+    # XXX: Not sure if this is the best way to check if MKL-DNN is available
+    from paddle.inference import Config
+
+    return hasattr(Config, "set_mkldnn_cache_capacity")

+ 1 - 1
paddlex/inference/utils/official_models.py

@@ -473,7 +473,7 @@ class OfficialModelsDict(dict):
 
         if (
             MODEL_SOURCE.lower() == "huggingface"
-            and is_huggingface_accessible
+            and is_huggingface_accessible()
             and key in HUGGINGFACE_MODELS
         ):
             return _download_from_hf()

+ 33 - 4
paddlex/inference/utils/pp_option.py

@@ -23,13 +23,34 @@ from ...utils.device import (
     parse_device,
     set_env_for_device_type,
 )
-from ...utils.flags import USE_PIR_TRT
+from ...utils.flags import (
+    DISABLE_MKLDNN_MODEL_BL,
+    DISABLE_TRT_MODEL_BL,
+    ENABLE_MKLDNN_BYDEFAULT,
+    USE_PIR_TRT,
+)
+from .misc import is_mkldnn_available
 from .mkldnn_blocklist import MKLDNN_BLOCKLIST
 from .new_ir_blocklist import NEWIR_BLOCKLIST
 from .trt_blocklist import TRT_BLOCKLIST
 from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
 
 
+def get_default_run_mode(model_name, device_type):
+    """Choose the default inference run mode for a model/device pair.
+
+    Returns ``"mkldnn"`` only when all of the following hold: a model name
+    is given, the target device is ``"cpu"``, the
+    ``PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT`` flag is on, the installed Paddle
+    build supports MKL-DNN, and the model is not in ``MKLDNN_BLOCKLIST``.
+    In every other case it falls back to ``"paddle"``.
+    """
+    # Without a model name we cannot consult the blocklist, so stay safe.
+    if not model_name:
+        return "paddle"
+    # MKL-DNN is a CPU-only acceleration path.
+    if device_type != "cpu":
+        return "paddle"
+    if (
+        ENABLE_MKLDNN_BYDEFAULT
+        and is_mkldnn_available()
+        and model_name not in MKLDNN_BLOCKLIST
+    ):
+        return "mkldnn"
+    else:
+        return "paddle"
+
+
 class PaddlePredictorOption(object):
     """Paddle Inference Engine Option"""
 
@@ -104,7 +125,7 @@ class PaddlePredictorOption(object):
         device_type, device_ids = parse_device(get_default_device())
 
         default_config = {
-            "run_mode": "paddle",
+            "run_mode": get_default_run_mode(self.model_name, device_type),
             "device_type": device_type,
             "device_id": None if device_ids is None else device_ids[0],
             "cpu_threads": 8,
@@ -142,13 +163,21 @@ class PaddlePredictorOption(object):
 
         if self._model_name is not None:
             # TRT Blocklist
-            if run_mode.startswith("trt") and self._model_name in TRT_BLOCKLIST:
+            if (
+                not DISABLE_TRT_MODEL_BL
+                and run_mode.startswith("trt")
+                and self._model_name in TRT_BLOCKLIST
+            ):
                 logging.warning(
                     f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
                 )
                 run_mode = "paddle"
             # MKLDNN Blocklist
-            elif run_mode.startswith("mkldnn") and self._model_name in MKLDNN_BLOCKLIST:
+            elif (
+                not DISABLE_MKLDNN_MODEL_BL
+                and run_mode.startswith("mkldnn")
+                and self._model_name in MKLDNN_BLOCKLIST
+            ):
                 logging.warning(
                     f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
                 )

+ 8 - 0
paddlex/utils/flags.py

@@ -51,7 +51,15 @@ FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", True)
 USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", True)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
+DISABLE_TRT_MODEL_BL = get_flag_from_env_var("PADDLE_PDX_DISABLE_TRT_MODEL_BL", False)
+DISABLE_MKLDNN_MODEL_BL = get_flag_from_env_var(
+    "PADDLE_PDX_DISABLE_MKLDNN_MODEL_BL", False
+)
 LOCAL_FONT_FILE_PATH = get_flag_from_env_var("PADDLE_PDX_LOCAL_FONT_FILE_PATH", None)
+ENABLE_MKLDNN_BYDEFAULT = get_flag_from_env_var(
+    "PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT", True
+)
+
 MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface")