device.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import GPUtil
  16. import lazy_paddle as paddle
  17. from . import logging
  18. from .errors import raise_unsupported_device_error
  19. SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
  20. def _constr_device(device_type, device_ids):
  21. if device_ids:
  22. device_ids = ",".join(map(str, device_ids))
  23. return f"{device_type}:{device_ids}"
  24. else:
  25. return f"{device_type}"
  26. def get_default_device():
  27. avail_gpus = GPUtil.getAvailable()
  28. if not avail_gpus:
  29. return "cpu"
  30. else:
  31. return _constr_device("gpu", [avail_gpus[0]])
  32. def parse_device(device):
  33. """parse_device"""
  34. # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
  35. parts = device.split(":")
  36. if len(parts) > 2:
  37. raise ValueError(f"Invalid device: {device}")
  38. if len(parts) == 1:
  39. device_type, device_ids = parts[0], None
  40. else:
  41. device_type, device_ids = parts
  42. device_ids = device_ids.split(",")
  43. for device_id in device_ids:
  44. if not device_id.isdigit():
  45. raise ValueError(
  46. f"Device ID must be an integer. Invalid device ID: {device_id}"
  47. )
  48. device_ids = list(map(int, device_ids))
  49. device_type = device_type.lower()
  50. # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
  51. assert device_type.lower() in SUPPORTED_DEVICE_TYPE
  52. return device_type, device_ids
  53. def update_device_num(device, num):
  54. device_type, device_ids = parse_device(device)
  55. if device_ids:
  56. assert len(device_ids) >= num
  57. return _constr_device(device_type, device_ids[:num])
  58. else:
  59. return _constr_device(device_type, device_ids)
  60. def set_env_for_device(device):
  61. def _set(envs):
  62. for key, val in envs.items():
  63. os.environ[key] = val
  64. logging.debug(f"{key} has been set to {val}.")
  65. device_type, device_ids = parse_device(device)
  66. if device_type.lower() in ["gpu", "xpu", "npu", "mlu"]:
  67. if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
  68. envs = {"FLAGS_conv_workspace_size_limit": "2000"}
  69. _set(envs)
  70. if device_type.lower() == "npu":
  71. envs = {
  72. "FLAGS_npu_jit_compile": "0",
  73. "FLAGS_use_stride_kernel": "0",
  74. "FLAGS_allocator_strategy": "auto_growth",
  75. "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
  76. "FLAGS_npu_scale_aclnn": "True",
  77. "FLAGS_npu_split_aclnn": "True",
  78. }
  79. _set(envs)
  80. if device_type.lower() == "xpu":
  81. envs = {
  82. "BKCL_FORCE_SYNC": "1",
  83. "BKCL_TIMEOUT": "1800",
  84. "FLAGS_use_stride_kernel": "0",
  85. }
  86. _set(envs)
  87. if device_type.lower() == "mlu":
  88. envs = {"FLAGS_use_stride_kernel": "0"}
  89. _set(envs)