|
@@ -92,34 +92,34 @@ def set_env_for_device(device):
|
|
|
logging.debug(f"{key} has been set to {val}.")
|
|
logging.debug(f"{key} has been set to {val}.")
|
|
|
|
|
|
|
|
device_type, device_ids = parse_device(device)
|
|
device_type, device_ids = parse_device(device)
|
|
|
- if device_type.lower() in ["gpu", "xpu", "npu", "mlu", "gcu", "dcu"]:
|
|
|
|
|
- if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
|
|
|
|
|
- envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
|
|
|
- _set(envs)
|
|
|
|
|
- if device_type.lower() == "npu":
|
|
|
|
|
- envs = {
|
|
|
|
|
- "FLAGS_npu_jit_compile": "0",
|
|
|
|
|
- "FLAGS_use_stride_kernel": "0",
|
|
|
|
|
- "FLAGS_allocator_strategy": "auto_growth",
|
|
|
|
|
- "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
|
|
|
|
|
- "FLAGS_npu_scale_aclnn": "True",
|
|
|
|
|
- "FLAGS_npu_split_aclnn": "True",
|
|
|
|
|
- }
|
|
|
|
|
- _set(envs)
|
|
|
|
|
- if device_type.lower() == "xpu":
|
|
|
|
|
- envs = {
|
|
|
|
|
- "BKCL_FORCE_SYNC": "1",
|
|
|
|
|
- "BKCL_TIMEOUT": "1800",
|
|
|
|
|
- "FLAGS_use_stride_kernel": "0",
|
|
|
|
|
- "XPU_BLACK_LIST": "pad3d",
|
|
|
|
|
- }
|
|
|
|
|
- _set(envs)
|
|
|
|
|
- if device_type.lower() == "mlu":
|
|
|
|
|
- envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
|
|
- _set(envs)
|
|
|
|
|
- if device_type.lower() == "gcu":
|
|
|
|
|
- envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
|
|
- _set(envs)
|
|
|
|
|
|
|
+ # XXX: is_compiled_with_rocm() must be True on dcu platform ?
|
|
|
|
|
+ if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
|
|
|
|
|
+ envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
|
|
|
+ _set(envs)
|
|
|
|
|
+ if device_type.lower() == "npu":
|
|
|
|
|
+ envs = {
|
|
|
|
|
+ "FLAGS_npu_jit_compile": "0",
|
|
|
|
|
+ "FLAGS_use_stride_kernel": "0",
|
|
|
|
|
+ "FLAGS_allocator_strategy": "auto_growth",
|
|
|
|
|
+ "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
|
|
|
|
|
+ "FLAGS_npu_scale_aclnn": "True",
|
|
|
|
|
+ "FLAGS_npu_split_aclnn": "True",
|
|
|
|
|
+ }
|
|
|
|
|
+ _set(envs)
|
|
|
|
|
+ if device_type.lower() == "xpu":
|
|
|
|
|
+ envs = {
|
|
|
|
|
+ "BKCL_FORCE_SYNC": "1",
|
|
|
|
|
+ "BKCL_TIMEOUT": "1800",
|
|
|
|
|
+ "FLAGS_use_stride_kernel": "0",
|
|
|
|
|
+ "XPU_BLACK_LIST": "pad3d",
|
|
|
|
|
+ }
|
|
|
|
|
+ _set(envs)
|
|
|
|
|
+ if device_type.lower() == "mlu":
|
|
|
|
|
+ envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
|
|
+ _set(envs)
|
|
|
|
|
+ if device_type.lower() == "gcu":
|
|
|
|
|
+ envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
|
|
+ _set(envs)
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_supported_device(device, model_name):
|
|
def check_supported_device(device, model_name):
|