|
@@ -31,6 +31,8 @@ SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu"]
|
|
|
|
|
|
|
|
|
|
|
|
|
def constr_device(device_type, device_ids):
|
|
def constr_device(device_type, device_ids):
|
|
|
|
|
+ if device_type == "cpu" and device_ids is not None:
|
|
|
|
|
+ raise ValueError("`device_ids` must be None for CPUs")
|
|
|
if device_ids:
|
|
if device_ids:
|
|
|
device_ids = ",".join(map(str, device_ids))
|
|
device_ids = ",".join(map(str, device_ids))
|
|
|
return f"{device_type}:{device_ids}"
|
|
return f"{device_type}:{device_ids}"
|
|
@@ -73,6 +75,8 @@ def parse_device(device):
|
|
|
device_type = device_type.lower()
|
|
device_type = device_type.lower()
|
|
|
# raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
|
|
# raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
|
|
|
assert device_type.lower() in SUPPORTED_DEVICE_TYPE
|
|
assert device_type.lower() in SUPPORTED_DEVICE_TYPE
|
|
|
|
|
+ if device_type == "cpu" and device_ids is not None:
|
|
|
|
|
+ raise ValueError("No Device ID should be specified for CPUs")
|
|
|
return device_type, device_ids
|
|
return device_type, device_ids
|
|
|
|
|
|
|
|
|
|
|
|
@@ -86,12 +90,16 @@ def update_device_num(device, num):
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_env_for_device(device):
|
|
def set_env_for_device(device):
|
|
|
|
|
+ device_type, _ = parse_device(device)
|
|
|
|
|
+ return set_env_for_device_type(device_type)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def set_env_for_device_type(device_type):
|
|
|
def _set(envs):
|
|
def _set(envs):
|
|
|
for key, val in envs.items():
|
|
for key, val in envs.items():
|
|
|
os.environ[key] = val
|
|
os.environ[key] = val
|
|
|
logging.debug(f"{key} has been set to {val}.")
|
|
logging.debug(f"{key} has been set to {val}.")
|
|
|
|
|
|
|
|
- device_type, device_ids = parse_device(device)
|
|
|
|
|
# XXX: is_compiled_with_rocm() must be True on dcu platform ?
|
|
# XXX: is_compiled_with_rocm() must be True on dcu platform ?
|
|
|
if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
|
|
if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
|
|
|
envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
@@ -122,17 +130,12 @@ def set_env_for_device(device):
|
|
|
_set(envs)
|
|
_set(envs)
|
|
|
|
|
|
|
|
|
|
|
|
|
-def check_supported_device(device, model_name):
|
|
|
|
|
|
|
+def check_supported_device_type(device_type, model_name):
|
|
|
if DISABLE_DEV_MODEL_WL:
|
|
if DISABLE_DEV_MODEL_WL:
|
|
|
logging.warning(
|
|
logging.warning(
|
|
|
"Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
|
|
"Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
|
|
|
)
|
|
)
|
|
|
return
|
|
return
|
|
|
- device_type, device_ids = parse_device(device)
|
|
|
|
|
- return check_supported_device_type(device_type, model_name)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def check_supported_device_type(device_type, model_name):
|
|
|
|
|
if device_type == "dcu":
|
|
if device_type == "dcu":
|
|
|
assert (
|
|
assert (
|
|
|
model_name in DCU_WHITELIST
|
|
model_name in DCU_WHITELIST
|
|
@@ -153,3 +156,8 @@ def check_supported_device_type(device_type, model_name):
|
|
|
assert (
|
|
assert (
|
|
|
model_name in GCU_WHITELIST
|
|
model_name in GCU_WHITELIST
|
|
|
), f"The GCU device does not yet support `{model_name}` model!"
|
|
), f"The GCU device does not yet support `{model_name}` model!"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def check_supported_device(device, model_name):
|
|
|
|
|
+ device_type, _ = parse_device(device)
|
|
|
|
|
+ return check_supported_device_type(device_type, model_name)
|