utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from contextlib import ExitStack
from typing import ContextManager, List

import paddle

from ..utils import device_guard


class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
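

# Usage sketch (illustrative, not part of the original module): enter several
# context managers at once with `ContextManagers`. The temporary files below are
# an assumption made only for this example.
def _example_context_managers():
    import tempfile

    f1 = tempfile.TemporaryFile()
    f2 = tempfile.TemporaryFile()
    with ContextManagers([f1, f2]):
        # Both files are open here; the underlying ExitStack closes them together on exit.
        f1.write(b"hello")
        f2.write(b"world")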


def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it is being called with, and build a
    dict mapping argument names to their values (positional values first, then
    declared defaults, then explicit keyword arguments).
    """
    if hasattr(inspect, "getfullargspec"):
        (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = inspect.getfullargspec(func)
    else:
        (spec_args, spec_varargs, spec_varkw, spec_defaults) = inspect.getargspec(func)
    # add positional argument values
    init_dict = dict(zip(spec_args, args))
    # add default argument values
    kwargs_dict = dict(zip(spec_args[-len(spec_defaults) :], spec_defaults)) if spec_defaults else {}
    for k in list(kwargs_dict.keys()):
        if k in init_dict:
            kwargs_dict.pop(k)
    kwargs_dict.update(kwargs)
    init_dict.update(kwargs_dict)
    return init_dict
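

# Usage sketch (illustrative, not part of the original module): `fn_args_to_dict`
# merges positional values, declared defaults, and explicit keyword arguments into
# a single name -> value dict. The sample function below is hypothetical.
def _example_fn_args_to_dict():
    def make_model(hidden_size, num_layers=2, dropout=0.1):
        pass

    config = fn_args_to_dict(make_model, 768, dropout=0.2)
    # config == {"hidden_size": 768, "num_layers": 2, "dropout": 0.2}
    return config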


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    subfolder="",
    **kwargs,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
      the Hub
    - return the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
    import json

    if not os.path.isfile(index_filename):
        raise ValueError(
            f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}."
        )

    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    file_map = {file: set() for file in shard_filenames}
    for weight, file in index["weight_map"].items():
        file_map[file].add(weight)
    sharded_metadata["file_map"] = file_map

    # Adapt for PaddleX: only a local checkpoint directory is supported here.
    assert os.path.isdir(pretrained_model_name_or_path)
    shard_filenames = [
        os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames
    ]
    return shard_filenames, sharded_metadata
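

# Usage sketch (illustrative, not part of the original module): resolve the shard
# files of a locally saved sharded checkpoint. The directory and index file name
# below are assumptions for the example; the actual index name depends on the
# checkpoint format on disk.
def _example_get_checkpoint_shard_files():
    model_dir = "/path/to/local_model"  # hypothetical local checkpoint directory
    index_file = os.path.join(model_dir, "model_state.pdparams.index.json")
    shard_files, sharded_metadata = get_checkpoint_shard_files(model_dir, index_file)
    # `shard_files` lists the full shard paths; `sharded_metadata["weight_map"]`
    # maps each parameter name to the shard file that stores it.
    return shard_files, sharded_metadata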


def is_paddle_support_lazy_init():
    return hasattr(paddle, "LazyGuard")
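

# Usage sketch (illustrative, not part of the original module): gate lazy parameter
# initialization on `paddle.LazyGuard` being available. The Linear layer is only an
# example; any paddle.nn.Layer could be constructed inside the guard.
def _example_lazy_init():
    if is_paddle_support_lazy_init():
        with paddle.LazyGuard():
            # Parameters are created lazily inside the guard instead of being
            # allocated immediately.
            layer = paddle.nn.Linear(8, 8)
    else:
        layer = paddle.nn.Linear(8, 8)
    return layer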


def paddlenlp_load(path, map_location="cpu"):
    assert map_location in ["cpu", "gpu", "xpu", "npu", "numpy", "np"]
    if map_location in ["numpy", "np"]:
        return paddle.load(path, return_numpy=True)
    else:
        with device_guard(map_location):
            return paddle.load(path)

    # TODO(zhonghui03): the following code has problems when hot starting from an
    # optimizer checkpoint. It is kept for reference but is unreachable, since both
    # branches above already return.
    if map_location == "cpu":
        from paddle.framework.io import (
            _parse_every_object,
            _to_LodTensor,
            _transformed_from_lodtensor,
        )

        def _ndarray_to_tensor(obj, return_numpy=False):
            if return_numpy:
                return obj
            if paddle.in_dynamic_mode():
                return paddle.Tensor(obj, zero_copy=True)
            else:
                return _to_LodTensor(obj)

        state_dict = paddle.load(path, return_numpy=True)
        # Hack for zero-copy to save loading time: paddle.load would otherwise copy
        # each ndarray when creating a paddle.Tensor.
        return _parse_every_object(state_dict, _transformed_from_lodtensor, _ndarray_to_tensor)
    else:
        return paddle.load(path)
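

# Usage sketch (illustrative, not part of the original module): load a checkpoint
# onto a specific device, or as numpy arrays to avoid allocating device memory.
# The file path is a hypothetical placeholder.
def _example_paddlenlp_load():
    ckpt_path = "/path/to/model_state.pdparams"  # hypothetical checkpoint file
    cpu_state_dict = paddlenlp_load(ckpt_path, map_location="cpu")
    numpy_state_dict = paddlenlp_load(ckpt_path, map_location="np")
    return cpu_state_dict, numpy_state_dict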


def use_hybrid_parallel():
    try:
        from paddle.distributed import fleet

        hcg = fleet.get_hybrid_communicate_group()
        return hcg
    except Exception:
        # fleet is not initialized (single-card or non-distributed run).
        return None


def weight_name_suffix():
    hcg = use_hybrid_parallel()
    if hcg is not None:
        name = []
        if hcg.get_model_parallel_world_size() > 1:
            name.append(f"tp{hcg.get_model_parallel_rank():0>2d}")
        if hcg.get_pipe_parallel_world_size() > 1:
            name.append(f"pp{hcg.get_stage_id():0>2d}")
        return "_".join(name)
    else:
        return None
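

# Usage sketch (illustrative, not part of the original module): build a rank-specific
# weight file name under hybrid parallelism, e.g. "model_state_tp00_pp01.pdparams" for
# tensor-parallel rank 0 and pipeline stage 1. The base file name is an assumption
# made for this example.
def _example_sharded_weight_name(base_name="model_state.pdparams"):
    suffix = weight_name_suffix()
    if suffix:
        stem, ext = os.path.splitext(base_name)
        return f"{stem}_{suffix}{ext}"
    # No hybrid parallel group (or a single-rank group): keep the plain name.
    return base_name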