Эх сурвалжийг харах

support to disable device model white list (#3306)

Tingquan Gao 9 сар өмнө
parent
commit
e66748f238

+ 6 - 0
paddlex/utils/device.py

@@ -17,6 +17,7 @@ import GPUtil
 
 import lazy_paddle as paddle
 from . import logging
+from .flags import DISABLE_DEV_MODEL_WL
 from .errors import raise_unsupported_device_error
 from .custom_device_whitelist import (
     DCU_WHITELIST,
@@ -122,6 +123,11 @@ def set_env_for_device(device):
 
 
 def check_supported_device(device, model_name):
+    if DISABLE_DEV_MODEL_WL:
+        logging.warning(
+            "Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
+        )
+        return
     device_type, device_ids = parse_device(device)
     if device_type == "dcu":
         assert (

+ 2 - 0
paddlex/utils/flags.py

@@ -27,6 +27,7 @@ __all__ = [
     "INFER_BENCHMARK_DATA_SIZE",
     "FLAGS_json_format_model",
     "USE_PIR_TRT",
+    "DISABLE_DEV_MODEL_WL",
 ]
 
 
@@ -48,6 +49,7 @@ CHECK_OPTS = get_flag_from_env_var("PADDLE_PDX_CHECK_OPTS", False)
 EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None)
 USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", False)
+DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 
 # Inference Benchmark
 INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None)