utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import inspect
import json
import os
from contextlib import ExitStack
from typing import ContextManager, List

import paddle

from ..utils import device_guard


class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
    Adaptation of `ContextManagers` in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)
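
# Usage sketch (illustrative only; the managers passed in below are arbitrary examples,
# not something this module prescribes):
#
#     with ContextManagers([paddle.no_grad(), device_guard("cpu")]):
#         ...  # every manager in the list is entered on entry and all are exited together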


def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it is being called with, and extract a
    dict mapping argument names to values (positional, keyword, and defaults).
    """
    if hasattr(inspect, "getfullargspec"):
        (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = inspect.getfullargspec(func)
    else:
        (spec_args, spec_varargs, spec_varkw, spec_defaults) = inspect.getargspec(func)
    # Add positional argument values.
    init_dict = dict(zip(spec_args, args))
    # Add default argument values.
    kwargs_dict = (
        dict(zip(spec_args[-len(spec_defaults) :], spec_defaults))
        if spec_defaults
        else {}
    )
    # Drop defaults that were already supplied positionally.
    for k in list(kwargs_dict.keys()):
        if k in init_dict:
            kwargs_dict.pop(k)
    kwargs_dict.update(kwargs)
    init_dict.update(kwargs_dict)
    return init_dict
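
# Usage sketch with a hypothetical function `f`:
#
#     def f(a, b, c=3):
#         ...
#
#     fn_args_to_dict(f, 1, b=2)  # -> {"a": 1, "b": 2, "c": 3}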


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    subfolder="",
    **kwargs,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
      the Hub
    - return the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
    if not os.path.isfile(index_filename):
        raise ValueError(
            f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}."
        )

    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    # Map each shard file to the set of weight names it stores.
    file_map = {file: set() for file in shard_filenames}
    for weight, file in index["weight_map"].items():
        file_map[file].add(weight)
    sharded_metadata["file_map"] = file_map

    # Adapted for PaddleX: the checkpoint is expected to be a local directory.
    assert os.path.isdir(pretrained_model_name_or_path)
    shard_filenames = [
        os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames
    ]
    return shard_filenames, sharded_metadata
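
# Usage sketch (paths and shard file names below are hypothetical). The index file is the
# JSON written next to a sharded checkpoint; the code above only relies on its "metadata"
# and "weight_map" entries, e.g.:
#
#     {
#         "metadata": {"total_size": 123456},
#         "weight_map": {"linear.weight": "shard-00001-of-00002.pdparams",
#                        "linear.bias": "shard-00002-of-00002.pdparams"}
#     }
#
#     shard_files, sharded_metadata = get_checkpoint_shard_files(
#         "/path/to/model_dir",
#         "/path/to/model_dir/model_index.json",
#     )
#     # `shard_files` lists the absolute shard paths; `sharded_metadata["file_map"]`
#     # maps each shard file to the set of weight names stored in it.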


def is_paddle_support_lazy_init():
    # Lazy initialization is available when the installed Paddle exposes `LazyGuard`.
    return hasattr(paddle, "LazyGuard")


def is_safetensors_available():
    # `safetensors` is an optional dependency; only check whether it is importable.
    return importlib.util.find_spec("safetensors") is not None


def paddlenlp_load(path, map_location="cpu"):
    assert map_location in ["cpu", "gpu", "xpu", "npu", "numpy", "np"]
    if map_location in ["numpy", "np"]:
        return paddle.load(path, return_numpy=True)
    else:
        with device_guard(map_location):
            return paddle.load(path)

    # NOTE: the code below is unreachable (both branches above return); it is kept
    # alongside the TODO for reference.
    # TODO(zhonghui03): the following code has problems when hot-starting from an optimizer checkpoint.
    if map_location == "cpu":
        from paddle.framework.io import (
            _parse_every_object,
            _to_LodTensor,
            _transformed_from_lodtensor,
        )

        def _ndarray_to_tensor(obj, return_numpy=False):
            if return_numpy:
                return obj
            if paddle.in_dynamic_mode():
                return paddle.Tensor(obj, zero_copy=True)
            else:
                return _to_LodTensor(obj)

        state_dict = paddle.load(path, return_numpy=True)
        # Zero-copy hack to save loading time: plain `paddle.load` would copy each ndarray
        # when creating a `paddle.Tensor`.
        return _parse_every_object(
            state_dict, _transformed_from_lodtensor, _ndarray_to_tensor
        )
    else:
        return paddle.load(path)
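
# Usage sketch (the checkpoint path is hypothetical):
#
#     state_dict = paddlenlp_load("/path/to/model_state.pdparams", map_location="cpu")
#     arrays = paddlenlp_load("/path/to/model_state.pdparams", map_location="np")  # numpy ndarrays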


def use_hybrid_parallel():
    # Return the hybrid communicate group if a fleet distributed environment has been
    # initialized, otherwise None.
    try:
        from paddle.distributed import fleet

        hcg = fleet.get_hybrid_communicate_group()
        return hcg
    except:
        return None


def weight_name_suffix():
    # Build a per-rank suffix (e.g. "tp00_pp01") so that each tensor-parallel /
    # pipeline-parallel rank saves its weights under a distinct file name.
    hcg = use_hybrid_parallel()
    if hcg is not None:
        name = []
        if hcg.get_model_parallel_world_size() > 1:
            name.append(f"tp{hcg.get_model_parallel_rank():0>2d}")
        if hcg.get_pipe_parallel_world_size() > 1:
            name.append(f"pp{hcg.get_stage_id():0>2d}")
        return "_".join(name)
    else:
        return None
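
# Usage sketch: in a run with 2-way tensor parallelism and 2 pipeline stages, the rank with
# tensor-parallel rank 1 and pipeline stage 0 gets the suffix "tp01_pp00"; outside a fleet
# environment the function returns None. A caller might use it like (file name illustrative):
#
#     suffix = weight_name_suffix()
#     filename = "model_state.pdparams" if suffix is None else f"model_state_{suffix}.pdparams"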