gaotingquan hai 9 meses
pai
achega
2dd7e3dbf9
Modificáronse 1 ficheiros con 28 adicións e 28 borrados
  1. 28 28
      paddlex/utils/device.py

+ 28 - 28
paddlex/utils/device.py

@@ -92,34 +92,34 @@ def set_env_for_device(device):
             logging.debug(f"{key} has been set to {val}.")
 
     device_type, device_ids = parse_device(device)
-    if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu", "dcu"]:
-        if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
-            envs = {"FLAGS_conv_workspace_size_limit": "2000"}
-            _set(envs)
-        if device_type.lower() == "npu":
-            envs = {
-                "FLAGS_npu_jit_compile": "0",
-                "FLAGS_use_stride_kernel": "0",
-                "FLAGS_allocator_strategy": "auto_growth",
-                "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
-                "FLAGS_npu_scale_aclnn": "True",
-                "FLAGS_npu_split_aclnn": "True",
-            }
-            _set(envs)
-        if device_type.lower() == "xpu":
-            envs = {
-                "BKCL_FORCE_SYNC": "1",
-                "BKCL_TIMEOUT": "1800",
-                "FLAGS_use_stride_kernel": "0",
-                "XPU_BLACK_LIST": "pad3d",
-            }
-            _set(envs)
-        if device_type.lower() == "mlu":
-            envs = {"FLAGS_use_stride_kernel": "0"}
-            _set(envs)
-        if device_type.lower() == "gcu":
-            envs = {"FLAGS_use_stride_kernel": "0"}
-            _set(envs)
+    # XXX: is_compiled_with_rocm() must be True on dcu platform ?
+    if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
+        envs = {"FLAGS_conv_workspace_size_limit": "2000"}
+        _set(envs)
+    if device_type.lower() == "npu":
+        envs = {
+            "FLAGS_npu_jit_compile": "0",
+            "FLAGS_use_stride_kernel": "0",
+            "FLAGS_allocator_strategy": "auto_growth",
+            "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
+            "FLAGS_npu_scale_aclnn": "True",
+            "FLAGS_npu_split_aclnn": "True",
+        }
+        _set(envs)
+    if device_type.lower() == "xpu":
+        envs = {
+            "BKCL_FORCE_SYNC": "1",
+            "BKCL_TIMEOUT": "1800",
+            "FLAGS_use_stride_kernel": "0",
+            "XPU_BLACK_LIST": "pad3d",
+        }
+        _set(envs)
+    if device_type.lower() == "mlu":
+        envs = {"FLAGS_use_stride_kernel": "0"}
+        _set(envs)
+    if device_type.lower() == "gcu":
+        envs = {"FLAGS_use_stride_kernel": "0"}
+        _set(envs)
 
 
 def check_supported_device(device, model_name):