|
|
@@ -92,8 +92,8 @@ def set_env_for_device(device):
|
|
|
logging.debug(f"{key} has been set to {val}.")
|
|
|
|
|
|
device_type, device_ids = parse_device(device)
|
|
|
- if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu"]:
|
|
|
- if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
|
|
|
+ if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu", "dcu"]:
|
|
|
+ if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
|
|
|
envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
|
_set(envs)
|
|
|
if device_type.lower() == "npu":
|