Procházet zdrojové kódy

Add GPU compatibility check (#4499)

timminator před 2 měsíci
rodič
revize
30621ce497
2 změnil soubory, kde provedl 22 přidání a 1 odebrání
  1. 18 1
      paddlex/inference/utils/pp_option.py
  2. 4 0
      paddlex/utils/flags.py

+ 18 - 1
paddlex/inference/utils/pp_option.py

@@ -18,7 +18,7 @@ from typing import Dict, List
 
 from ...utils import logging
 from ...utils.device import get_default_device, parse_device, set_env_for_device_type
-from ...utils.flags import ENABLE_MKLDNN_BYDEFAULT, USE_PIR_TRT
+from ...utils.flags import ENABLE_MKLDNN_BYDEFAULT, USE_PIR_TRT, DISABLE_DEVICE_FALLBACK
 from .misc import is_mkldnn_available
 from .mkldnn_blocklist import MKLDNN_BLOCKLIST
 from .new_ir_blocklist import NEWIR_BLOCKLIST
@@ -81,6 +81,23 @@ class PaddlePredictorOption(object):
         for k, v in self._get_default_config(model_name).items():
             self._cfg.setdefault(k, v)
 
+        if self.device_type == "gpu":
+            import paddle
+
+            if not (paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0):
+                if DISABLE_DEVICE_FALLBACK:
+                    raise RuntimeError(
+                        "Device fallback is disabled and the specified device (GPU) is not available. "
+                        "To fall back to CPU instead, unset the PADDLE_PDX_DISABLE_DEVICE_FALLBACK environment variable."
+                    )
+                else:
+                    logging.warning(
+                        "The specified device (GPU) is not available! Switching to CPU instead."
+                    )
+                self.device_type = "cpu"
+                self.run_mode = get_default_run_mode(model_name, "cpu")
+                self.device_id = None
+
         # for trt
         if self.run_mode in ("trt_int8", "trt_fp32", "trt_fp16"):
             trt_cfg_setting = TRT_CFG_SETTING[model_name]

+ 4 - 0
paddlex/utils/flags.py

@@ -29,6 +29,7 @@ __all__ = [
     "USE_PIR_TRT",
     "DISABLE_DEV_MODEL_WL",
     "DISABLE_CINN_MODEL_WL",
+    "DISABLE_DEVICE_FALLBACK",
 ]
 
 
@@ -60,6 +61,9 @@ LOCAL_FONT_FILE_PATH = get_flag_from_env_var("PADDLE_PDX_LOCAL_FONT_FILE_PATH",
 ENABLE_MKLDNN_BYDEFAULT = get_flag_from_env_var(
     "PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT", True
 )
+DISABLE_DEVICE_FALLBACK = get_flag_from_env_var(
+    "PADDLE_PDX_DISABLE_DEVICE_FALLBACK", False
+)
 
 MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface").lower()