|
|
@@ -15,6 +15,8 @@
|
|
|
import os
|
|
|
import GPUtil
|
|
|
import lazy_paddle as paddle
|
|
|
+
|
|
|
+from . import logging
|
|
|
from .errors import raise_unsupported_device_error
|
|
|
|
|
|
SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
|
|
|
@@ -69,22 +71,33 @@ def update_device_num(device, num):
|
|
|
|
|
|
|
|
|
def set_env_for_device(device):
|
|
|
+ def _set(envs):
|
|
|
+ for key, val in envs.items():
|
|
|
+ os.environ[key] = val
|
|
|
+ logging.debug(f"{key} has been set to {val}.")
|
|
|
+
|
|
|
device_type, device_ids = parse_device(device)
|
|
|
if device_type.lower() in ["gpu", "xpu", "npu", "mlu"]:
|
|
|
if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
|
|
|
- os.environ["FLAGS_conv_workspace_size_limit"] = "2000"
|
|
|
+ envs = {"FLAGS_conv_workspace_size_limit": "2000"}
|
|
|
+ _set(envs)
|
|
|
if device_type.lower() == "npu":
|
|
|
- os.environ["FLAGS_npu_jit_compile"] = "0"
|
|
|
- os.environ["FLAGS_use_stride_kernel"] = "0"
|
|
|
- os.environ["FLAGS_allocator_strategy"] = "auto_growth"
|
|
|
- os.environ["CUSTOM_DEVICE_BLACK_LIST"] = (
|
|
|
- "pad3d,pad3d_grad,set_value,set_value_with_tensor"
|
|
|
- )
|
|
|
- os.environ["FLAGS_npu_scale_aclnn"] = "True"
|
|
|
- os.environ["FLAGS_npu_split_aclnn"] = "True"
|
|
|
+ envs = {
|
|
|
+ "FLAGS_npu_jit_compile": "0",
|
|
|
+ "FLAGS_use_stride_kernel": "0",
|
|
|
+ "FLAGS_allocator_strategy": "auto_growth",
|
|
|
+ "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
|
|
|
+ "FLAGS_npu_scale_aclnn": "True",
|
|
|
+ "FLAGS_npu_split_aclnn": "True",
|
|
|
+ }
|
|
|
+ _set(envs)
|
|
|
if device_type.lower() == "xpu":
|
|
|
- os.environ["BKCL_FORCE_SYNC"] = "1"
|
|
|
- os.environ["BKCL_TIMEOUT"] = "1800"
|
|
|
- os.environ["FLAGS_use_stride_kernel"] = "0"
|
|
|
+ envs = {
|
|
|
+ "BKCL_FORCE_SYNC": "1",
|
|
|
+ "BKCL_TIMEOUT": "1800",
|
|
|
+ "FLAGS_use_stride_kernel": "0",
|
|
|
+ }
|
|
|
+ _set(envs)
|
|
|
if device_type.lower() == "mlu":
|
|
|
- os.environ["FLAGS_use_stride_kernel"] = "0"
|
|
|
+ envs = {"FLAGS_use_stride_kernel": "0"}
|
|
|
+ _set(envs)
|