device.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from contextlib import ContextDecorator
  16. from . import logging
  17. from .custom_device_list import (
  18. DCU_WHITELIST,
  19. GCU_WHITELIST,
  20. MLU_WHITELIST,
  21. NPU_BLACKLIST,
  22. XPU_WHITELIST,
  23. )
  24. from .flags import DISABLE_DEV_MODEL_WL
  25. SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu", "dcu"]
  26. def constr_device(device_type, device_ids):
  27. if device_type == "cpu" and device_ids is not None:
  28. raise ValueError("`device_ids` must be None for CPUs")
  29. if device_ids:
  30. device_ids = ",".join(map(str, device_ids))
  31. return f"{device_type}:{device_ids}"
  32. else:
  33. return f"{device_type}"
  34. def get_default_device():
  35. import paddle
  36. if paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0:
  37. return constr_device("gpu", [0])
  38. else:
  39. return "cpu"
  40. def parse_device(device):
  41. """parse_device"""
  42. # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
  43. parts = device.split(":")
  44. if len(parts) > 2:
  45. raise ValueError(f"Invalid device: {device}")
  46. if len(parts) == 1:
  47. device_type, device_ids = parts[0], None
  48. else:
  49. device_type, device_ids = parts
  50. device_ids = device_ids.split(",")
  51. for device_id in device_ids:
  52. if not device_id.isdigit():
  53. raise ValueError(
  54. f"Device ID must be an integer. Invalid device ID: {device_id}"
  55. )
  56. device_ids = list(map(int, device_ids))
  57. device_type = device_type.lower()
  58. # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
  59. assert device_type.lower() in SUPPORTED_DEVICE_TYPE
  60. if device_type == "cpu" and device_ids is not None:
  61. raise ValueError("No Device ID should be specified for CPUs")
  62. return device_type, device_ids
  63. def update_device_num(device, num):
  64. device_type, device_ids = parse_device(device)
  65. if device_ids:
  66. assert len(device_ids) >= num
  67. return constr_device(device_type, device_ids[:num])
  68. else:
  69. return constr_device(device_type, device_ids)
  70. def set_env_for_device(device):
  71. device_type, _ = parse_device(device)
  72. return set_env_for_device_type(device_type)
  73. def set_env_for_device_type(device_type):
  74. import paddle
  75. def _set(envs):
  76. for key, val in envs.items():
  77. os.environ[key] = val
  78. logging.debug(f"{key} has been set to {val}.")
  79. # XXX: is_compiled_with_rocm() must be True on dcu platform ?
  80. if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
  81. envs = {"FLAGS_conv_workspace_size_limit": "2000"}
  82. _set(envs)
  83. if device_type.lower() == "npu":
  84. envs = {
  85. "FLAGS_npu_jit_compile": "0",
  86. "FLAGS_use_stride_kernel": "0",
  87. "FLAGS_allocator_strategy": "auto_growth",
  88. "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
  89. "FLAGS_npu_scale_aclnn": "True",
  90. "FLAGS_npu_split_aclnn": "True",
  91. }
  92. _set(envs)
  93. if device_type.lower() == "xpu":
  94. envs = {
  95. "BKCL_FORCE_SYNC": "1",
  96. "BKCL_TIMEOUT": "1800",
  97. "FLAGS_use_stride_kernel": "0",
  98. "XPU_BLACK_LIST": "pad3d",
  99. }
  100. _set(envs)
  101. if device_type.lower() == "mlu":
  102. envs = {
  103. "FLAGS_use_stride_kernel": "0",
  104. "FLAGS_use_stream_safe_cuda_allocator": "0",
  105. }
  106. _set(envs)
  107. if device_type.lower() == "gcu":
  108. envs = {"FLAGS_use_stride_kernel": "0"}
  109. _set(envs)
  110. def check_supported_device_type(device_type, model_name):
  111. if DISABLE_DEV_MODEL_WL:
  112. logging.warning(
  113. "Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
  114. )
  115. return
  116. tips = "You could set env `PADDLE_PDX_DISABLE_DEV_MODEL_WL` to `true` to disable this checking."
  117. if device_type == "dcu":
  118. assert model_name in DCU_WHITELIST, (
  119. f"The DCU device does not yet support `{model_name}` model!" + tips
  120. )
  121. elif device_type == "mlu":
  122. assert model_name in MLU_WHITELIST, (
  123. f"The MLU device does not yet support `{model_name}` model!" + tips
  124. )
  125. elif device_type == "npu":
  126. assert model_name not in NPU_BLACKLIST, (
  127. f"The NPU device does not yet support `{model_name}` model!" + tips
  128. )
  129. elif device_type == "xpu":
  130. assert model_name in XPU_WHITELIST, (
  131. f"The XPU device does not yet support `{model_name}` model!" + tips
  132. )
  133. elif device_type == "gcu":
  134. assert model_name in GCU_WHITELIST, (
  135. f"The GCU device does not yet support `{model_name}` model!" + tips
  136. )
  137. def check_supported_device(device, model_name):
  138. device_type, _ = parse_device(device)
  139. return check_supported_device_type(device_type, model_name)
  140. class TemporaryDeviceChanger(ContextDecorator):
  141. """
  142. A context manager to temporarily change global device
  143. """
  144. def __init__(self, new_device):
  145. # if new_device is None, nothing changed
  146. import paddle
  147. self.new_device = new_device
  148. self.original_device = paddle.device.get_device()
  149. def __enter__(self):
  150. import paddle
  151. if self.new_device is None:
  152. return self
  153. paddle.device.set_device(self.new_device)
  154. return self
  155. def __exit__(self, exc_type, exc_val, exc_tb):
  156. import paddle
  157. if self.new_device is None:
  158. return False
  159. paddle.device.set_device(self.original_device)
  160. return False