distributed.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Union

import numpy as np
import paddle
import paddle.distributed as distributed

from .utils import device_guard

world_size = distributed.get_world_size()


def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
    """
    if isinstance(size, int):
        return size
    if size.upper().endswith("GIB"):
        return int(size[:-3]) * (2**30)
    if size.upper().endswith("MIB"):
        return int(size[:-3]) * (2**20)
    if size.upper().endswith("KIB"):
        return int(size[:-3]) * (2**10)
    if size.upper().endswith("GB"):
        int_size = int(size[:-2]) * (10**9)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("MB"):
        int_size = int(size[:-2]) * (10**6)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("KB"):
        int_size = int(size[:-2]) * (10**3)
        return int_size // 8 if size.endswith("b") else int_size
    raise ValueError(
        "`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'."
    )
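
# Illustrative examples (derived from the parsing rules above, not part of the
# original API documentation):
#     convert_file_size_to_int("5MB")   -> 5_000_000    (decimal megabytes)
#     convert_file_size_to_int("32MiB") -> 33_554_432   (binary mebibytes, 32 * 2**20)
#     convert_file_size_to_int(1024)    -> 1024         (ints pass through unchanged)
# A lowercase trailing "b" (e.g. "5Mb") is interpreted as bits and divided by 8.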


def reduce_tensor(tensor, buffer_size="32MiB"):
    """
    Yields flattened slices of `tensor`, each at most `buffer_size` bytes, together
    with the (start, end) element indices of the slice. Note that `tensor` itself
    is reshaped to 1-D in place while iterating.
    """
    if tensor.dtype == paddle.int8:
        numel = np.prod(tensor.shape)
    else:
        numel = int(paddle.numel(tensor).item())
    buffer_size = convert_file_size_to_int(buffer_size)
    tensor.reshape_([-1])
    send_size = buffer_size // dtype_byte_size(tensor.dtype)
    for x in range(0, numel, send_size):
        part_tensor = tensor[x : min(numel, x + send_size)]
        yield part_tensor, (x, min(numel, x + send_size))
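
# Illustrative sketch (single process, no communication involved): chunking a
# float32 tensor with a 1 KiB buffer yields slices of at most 256 elements.
#
#     t = paddle.arange(1000, dtype="float32")
#     for chunk, (start, end) in reduce_tensor(t, buffer_size="1KiB"):
#         assert end - start <= 256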


def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.
    """
    if dtype == paddle.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
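
# Illustrative examples (the values follow from the bit-width regex above):
#     dtype_byte_size(paddle.float32)  -> 4
#     dtype_byte_size(paddle.bfloat16) -> 2
#     dtype_byte_size(paddle.int8)     -> 1
#     dtype_byte_size(paddle.bool)     -> 0.125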


@paddle.no_grad()
def distributed_gather(tensor: Any, dst: int = 0, group=None, offload=False) -> Any:
    """nested gather function with offload

    Gathers `tensor` (or a nested list/tuple/dict of tensors) onto rank `dst`.
    With `offload=True`, the data is gathered in `reduce_tensor` chunks and the
    destination rank stores the result as CPU numpy arrays to limit device
    memory usage. Ranks other than `dst` get None in place of the gathered data.
    """
    try:
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(
                distributed_gather(t, dst, group, offload) for t in tensor
            )
        if isinstance(tensor, dict):
            return {
                k: distributed_gather(v, dst, group, offload) for k, v in tensor.items()
            }

        output_tensors = None
        is_dst = dst == distributed.get_rank(group=group)
        if is_dst:
            if offload:
                output_tensors = [
                    [] for _ in range(distributed.get_world_size(group=group))
                ]
            else:
                output_tensors = [
                    paddle.empty_like(tensor)
                    for _ in range(distributed.get_world_size(group=group))
                ]
                output_tensors = [
                    t if len(t.shape) > 0 else t[None] for t in output_tensors
                ]

        if offload:
            origin_shape = tensor.shape
            tensor.reshape_([-1])
            for slice_tensor, index in reduce_tensor(tensor):
                slice_output_tensors = None
                if distributed.get_rank(group=group) == dst:
                    slice_output_tensors = [
                        paddle.empty_like(slice_tensor)
                        for _ in range(distributed.get_world_size(group=group))
                    ]
                paddle.distributed.communication.stream.gather(
                    slice_tensor,
                    slice_output_tensors,
                    dst=group.ranks[dst] if group else dst,
                    group=group,
                    sync_op=True,
                    use_calc_stream=False,
                )
                if is_dst:
                    for i in range(len(output_tensors)):
                        output_tensors[i].append(slice_output_tensors[i].cpu().numpy())

            tensor.reshape_(origin_shape)

            if is_dst:
                # Reassemble each rank's chunks into a full array of the original shape.
                with device_guard("cpu"):
                    new_output_tensors = []
                    for x in output_tensors:
                        t = np.concatenate(x)
                        t = t.reshape(origin_shape)
                        new_output_tensors.append(t)
                    output_tensors = new_output_tensors
        else:
            paddle.distributed.communication.stream.gather(
                tensor,
                output_tensors,
                dst=group.ranks[dst] if group else dst,
                group=group,
                sync_op=True,
                use_calc_stream=False,
            )

        return output_tensors
    except AssertionError:
        raise AssertionError("Not currently using distributed training")
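
# Illustrative usage sketch (assumes the default process group has already been
# initialized, e.g. via `paddle.distributed.init_parallel_env()` under
# `paddle.distributed.launch`; the tensor below is made up for the example):
#
#     local = paddle.full([2, 2], float(distributed.get_rank()))
#     gathered = distributed_gather(local, dst=0, offload=True)
#     if distributed.get_rank() == 0:
#         print([g.shape for g in gathered])  # world_size numpy arrays of shape (2, 2)
#     else:
#         assert gathered is None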


@paddle.no_grad()
def distributed_allgather(tensor: Any, group=None, offload=False):
    """nested all gather function with offload

    Args:
        tensor (Any): the tensor, list of tensors, or dict of tensors to all-gather.
        group (optional): the communication group. Defaults to None.
        offload (bool, optional): If True, the received tensors are assembled on CPU. Defaults to False.

    Raises:
        AssertionError: Unexpected errors.

    Returns:
        tensor list: list of the gathered tensors, ordered by rank.
    """
    try:
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(
                distributed_allgather(t, group, offload) for t in tensor
            )
        if isinstance(tensor, dict):
            return {
                k: distributed_allgather(v, group, offload) for k, v in tensor.items()
            }

        output_tensors = []
        if offload:
            with device_guard("cpu"):
                output_tensors = [
                    paddle.empty_like(tensor)
                    for _ in range(distributed.get_world_size(group))
                ]
        else:
            output_tensors = [
                paddle.empty_like(tensor)
                for _ in range(distributed.get_world_size(group))
            ]
        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]

        if offload:
            origin_shape = tensor.shape
            tensor.reshape_([-1])
            for x in output_tensors:
                x.reshape_([-1])
            for slice_tensor, index in reduce_tensor(tensor):
                slice_output_tensors = [
                    paddle.empty_like(slice_tensor)
                    for _ in range(distributed.get_world_size(group))
                ]
                distributed.all_gather(slice_output_tensors, slice_tensor, group=group)
                for x, y in zip(slice_output_tensors, output_tensors):
                    with device_guard("cpu"):
                        y[index[0] : index[1]] = x.cpu()

            tensor.reshape_(origin_shape)
            for x in output_tensors:
                x.reshape_(origin_shape)
        else:
            distributed.all_gather(output_tensors, tensor, group=group)

        return output_tensors
    except AssertionError:
        raise AssertionError("Not currently using distributed training")
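
# Illustrative usage sketch (same initialization assumptions as for
# `distributed_gather` above): every rank receives the full list, ordered by
# rank, so no destination-rank check is needed.
#
#     local = paddle.ones([4]) * distributed.get_rank()
#     gathered = distributed_allgather(local, offload=False)
#     # `gathered` holds world_size tensors; gathered[i] came from rank i.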