# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import os
from contextlib import ExitStack
from typing import ContextManager, List

import paddle

from ..utils import device_guard


class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
    Adaptation of `ContextManagers` in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)


def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it is being called with, and build a dict
    mapping argument names to values (positional values first, then defaults, then
    keyword overrides).
    """
    if hasattr(inspect, "getfullargspec"):
        (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = inspect.getfullargspec(func)
    else:
        (spec_args, spec_varargs, spec_varkw, spec_defaults) = inspect.getargspec(func)
    # add positional argument values
    init_dict = dict(zip(spec_args, args))
    # add default argument values
    kwargs_dict = dict(zip(spec_args[-len(spec_defaults) :], spec_defaults)) if spec_defaults else {}
    for k in list(kwargs_dict.keys()):
        if k in init_dict:
            kwargs_dict.pop(k)
    kwargs_dict.update(kwargs)
    init_dict.update(kwargs_dict)
    return init_dict


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    subfolder="",
    **kwargs,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
      the Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Note: this PaddleX adaptation only supports `pretrained_model_name_or_path` being a local directory.
    """
    if not os.path.isfile(index_filename):
        raise ValueError(
            f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}."
        )

    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    file_map = {file: set() for file in shard_filenames}
    for weight, file in index["weight_map"].items():
        file_map[file].add(weight)
    sharded_metadata["file_map"] = file_map

    # Adapted for PaddleX: only local checkpoint directories are supported.
    assert os.path.isdir(pretrained_model_name_or_path)
    shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
    return shard_filenames, sharded_metadata


def is_paddle_support_lazy_init():
    return hasattr(paddle, "LazyGuard")


def paddlenlp_load(path, map_location="cpu"):
    assert map_location in ["cpu", "gpu", "xpu", "npu", "numpy", "np"]
    if map_location in ["numpy", "np"]:
        return paddle.load(path, return_numpy=True)
    else:
        with device_guard(map_location):
            return paddle.load(path)

    # NOTE: the block below is unreachable (both branches above return); it is kept for reference.
    # TODO(zhonghui03): the following code has problems when hot starting from an optimizer checkpoint.
    if map_location == "cpu":
        from paddle.framework.io import (
            _parse_every_object,
            _to_LodTensor,
            _transformed_from_lodtensor,
        )

        def _ndarray_to_tensor(obj, return_numpy=False):
            if return_numpy:
                return obj
            if paddle.in_dynamic_mode():
                return paddle.Tensor(obj, zero_copy=True)
            else:
                return _to_LodTensor(obj)

        state_dict = paddle.load(path, return_numpy=True)
        # Zero-copy hack to save loading time: paddle.load would otherwise copy when creating paddle.Tensor.
        return _parse_every_object(state_dict, _transformed_from_lodtensor, _ndarray_to_tensor)
    else:
        return paddle.load(path)


def use_hybrid_parallel():
    """Return the hybrid communicate group if hybrid parallelism is initialized, else None."""
    try:
        from paddle.distributed import fleet

        hcg = fleet.get_hybrid_communicate_group()
        return hcg
    except Exception:
        return None


def weight_name_suffix():
    hcg = use_hybrid_parallel()
    if hcg is not None:
        name = []
        if hcg.get_model_parallel_world_size() > 1:
            name.append(f"tp{hcg.get_model_parallel_rank():0>2d}")
        if hcg.get_pipe_parallel_world_size() > 1:
            name.append(f"pp{hcg.get_stage_id():0>2d}")
        return "_".join(name)
    else:
        return None
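

# Illustrative usage sketch, not part of the library API: a minimal demonstration of
# `fn_args_to_dict` and `ContextManagers` using hypothetical local helpers (`_demo_fn`,
# `_demo_ctx`). It runs only when this module is executed as a script (e.g. via
# `python -m`); normal imports are unaffected.
if __name__ == "__main__":
    from contextlib import contextmanager

    def _demo_fn(a, b=2, c=3):  # hypothetical function, used only for this demo
        return a + b + c

    # Maps argument names to the values the call would receive: {'a': 1, 'b': 5, 'c': 3}
    print(fn_args_to_dict(_demo_fn, 1, b=5))

    @contextmanager
    def _demo_ctx(tag):  # hypothetical context manager, used only for this demo
        print(f"enter {tag}")
        yield
        print(f"exit {tag}")

    # Enters both context managers on __enter__ and unwinds them in reverse order on __exit__.
    with ContextManagers([_demo_ctx("outer"), _demo_ctx("inner")]):
        print("inside")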