|
|
@@ -132,7 +132,10 @@ def set_env_for_device_type(device_type):
|
|
|
}
|
|
|
_set(envs)
|
|
|
if device_type.lower() == "mlu":
|
|
|
- envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
+ envs = {
|
|
|
+ "FLAGS_use_stride_kernel": "0",
|
|
|
+ "FLAGS_use_stream_safe_cuda_allocator": "0",
|
|
|
+ }
|
|
|
_set(envs)
|
|
|
if device_type.lower() == "gcu":
|
|
|
envs = {"FLAGS_use_stride_kernel": "0"}
|