瀏覽代碼

fix env setting when using dcu

gaotingquan 9 月之前
父節點
當前提交
75402dfe8a
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      paddlex/utils/device.py

+ 2 - 2
paddlex/utils/device.py

@@ -92,8 +92,8 @@ def set_env_for_device(device):
             logging.debug(f"{key} has been set to {val}.")
 
     device_type, device_ids = parse_device(device)
-    if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu"]:
-        if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
+    if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu", "dcu"]:
+        if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
             envs = {"FLAGS_conv_workspace_size_limit": "2000"}
             _set(envs)
         if device_type.lower() == "npu":