|
|
@@ -15,6 +15,7 @@
|
|
|
|
|
|
|
|
|
import os
|
|
|
+import paddle
|
|
|
from .errors import raise_unsupported_device_error
|
|
|
|
|
|
SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
|
|
|
@@ -26,12 +27,15 @@ def get_device(device_cfg, using_device_number=None):
|
|
|
device = device_cfg.split(":")[0]
|
|
|
assert device.lower() in SUPPORTED_DEVICE_TYPE
|
|
|
if device.lower() in ["gpu", "xpu", "npu", "mlu"]:
|
|
|
+ if device.lower() == "gpu" and paddle.is_compiled_with_rocm():
|
|
|
+ os.environ['FLAGS_conv_workspace_size_limit'] = '2000'
|
|
|
if device.lower() == "npu":
|
|
|
os.environ["FLAGS_npu_jit_compile"] = "0"
|
|
|
os.environ["FLAGS_use_stride_kernel"] = "0"
|
|
|
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
|
|
|
os.environ[
|
|
|
"CUSTOM_DEVICE_BLACK_LIST"] = "pad3d,pad3d_grad,set_value,set_value_with_tensor"
|
|
|
+
|
|
|
if len(device_cfg.split(":")) == 2:
|
|
|
device_ids = device_cfg.split(":")[1]
|
|
|
else:
|