model_utils.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import gc
import os
import re
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import paddle
import paddle.nn as nn
from paddle import Tensor
from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer
try:
    from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc
except ImportError:  # a bare `except:` would also swallow SystemExit/KeyboardInterrupt
    LocalSharedLayerDesc = None
from paddle.nn import Layer

from ......utils import logging
from ......utils.deps import is_dep_available, require_deps
from ...tokenizer.tokenizer_utils import InitTrackerMeta, adapt_stale_fwd_patch
from ..generation import GenerationConfig, GenerationMixin
from ..utils import (
    ASYMMETRY_QUANT_SCALE_MAX,
    ASYMMETRY_QUANT_SCALE_MIN,
    CONFIG_NAME,
    LEGACY_CONFIG_NAME,
    PADDLE_WEIGHTS_INDEX_NAME,
    PADDLE_WEIGHTS_NAME,
    PYTORCH_WEIGHTS_INDEX_NAME,
    PYTORCH_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    SYMMETRY_QUANT_SCALE,
    device_guard,
    resolve_file_path,
)
from .configuration_utils import PretrainedConfig
from .conversion_utils import ConversionMixin
from .utils import (
    ContextManagers,
    fn_args_to_dict,
    get_checkpoint_shard_files,
    paddlenlp_load,
    weight_name_suffix,
)

__all__ = [
    "PretrainedModel",
]


def _add_variant(weights_name: str, variant=None) -> str:
    if variant is not None and len(variant) > 0:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)
    return weights_name
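
# A minimal illustration of the variant splice above (the call is illustrative,
# not part of the module's API): the variant lands just before the suffix.
#
#     _add_variant("model_state.pdparams", "tp00")
#     # -> "model_state.tp00.pdparams"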

@contextmanager
def dtype_guard(dtype="float32"):
    origin_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(origin_dtype)
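
# Usage sketch (illustrative; `MyModel` and `config` are placeholders):
#
#     with dtype_guard("bfloat16"):
#         model = MyModel(config)  # parameters are created as bfloat16
#     # the previous default dtype is restored on exit, even if an error is raised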

_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    old_init_weights = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = old_init_weights
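
# Usage sketch (illustrative; `MyModel` is a placeholder): skip per-layer
# `_init_weights` when the checkpoint loaded right afterwards would overwrite
# the parameters anyway.
#
#     with no_init_weights():
#         model = MyModel(config)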

def _split_keys_evenly(keys: list, n: int) -> list:
    """Split a list into n lists with an equal number of elements.

    Args:
        keys (list): the list to be split
        n (int): number of splits

    Returns:
        result: list of lists
    """
    total_len = len(keys)
    base_size = total_len // n
    extra = total_len % n

    result = []
    index = 0
    for _ in range(n):
        part_size = base_size + 1 if extra > 0 else base_size
        extra -= 1
        result.append(keys[index : index + part_size])
        index += part_size
    return result
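
# Worked example: the first `len(keys) % n` sublists receive one extra element.
#
#     _split_keys_evenly(["a", "b", "c", "d", "e"], 2)
#     # -> [["a", "b", "c"], ["d", "e"]]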

def _load_part_state_dict_from_safetensors(
    keys,
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping,
    fliter_dict_keys,
    device,
    quantization_linear_list=None,
    quantization_config=None,
    dtype=None,
    return_numpy=False,
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    from safetensors import safe_open

    if transpose_weight_keys:
        transpose_weight_keys = set(transpose_weight_keys)

    def _is_need_transpose(key):
        if "lora" not in key and convert_from_hf and transpose_weight_keys:
            return key in transpose_weight_keys
        return False

    def _transpose_hf_weight(key, weight):
        if _is_need_transpose(key):
            return weight.transpose([-1, -2])
        return weight

    part_state_dict = {}
    scale_dict = {}
    with safe_open(checkpoint_file, framework="paddle") as f:
        for key in keys:
            # 1. Non-merged ckpt loading doesn't have filter keys.
            # 2. Merged ckpts skip quant scales via `fliter_dict_keys`.
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                continue
            if fliter_dict_keys is not None and key not in fliter_dict_keys:
                continue
            py_safe_slice_ = f.get_slice(key)
            if (
                quantization_linear_list is not None
                and key.split(".weight")[0] in quantization_linear_list
                and not key.endswith("_scale")
            ):
                raise NotImplementedError
            else:
                if key in tensor_parallel_split_mapping:
                    tp_fn = tensor_parallel_split_mapping[key]
                    if _is_need_transpose(key):
                        # flip `is_column` because the weight will be transposed below
                        assert isinstance(tp_fn, partial)
                        is_column = True
                        if "is_column" in tp_fn.keywords:
                            is_column = tp_fn.keywords["is_column"]
                        is_column = not is_column
                        tp_fn = partial(
                            tp_fn.func,
                            *tp_fn.args,
                            **{**tp_fn.keywords, "is_column": is_column},
                        )
                    if len(py_safe_slice_.shape) == 0:
                        weight = tp_fn(py_safe_slice_[:])
                    else:
                        weight = tp_fn(py_safe_slice_)
                else:
                    weight = py_safe_slice_[:]
                if not return_numpy and device == "expected":
                    weight = weight._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                weight = _transpose_hf_weight(key, weight)
            if return_numpy:
                weight = weight.numpy()
            part_state_dict[key] = weight
        for key in keys:
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                scale = f.get_tensor(key)
                if not return_numpy and device == "expected":
                    scale = scale._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                if return_numpy:
                    scale = scale.numpy()
                scale_dict[key] = scale
    return part_state_dict, scale_dict

def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
    """
    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}

    if Path(checkpoint_file).suffix == ".safetensors":
        require_deps("safetensors")
        from safetensors import safe_open

        with safe_open(checkpoint_file, framework="paddle") as f:
            state_dict, scale_dict = _load_part_state_dict_from_safetensors(
                list(f.keys()),
                checkpoint_file,
                tensor_parallel_split_mapping,
                fliter_dict_keys,
                "expected",
                dtype=None,
                return_numpy=False,
                convert_from_hf=convert_from_hf,
                transpose_weight_keys=transpose_weight_keys,
            )
    else:
        state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
    return state_dict
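
# Usage sketch (illustrative file names; the tensor-parallel mapping and filter
# keys are optional):
#
#     state_dict = load_state_dict("model-00001-of-00002.safetensors")
#     state_dict = load_state_dict("model_state.pdparams")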

_re_layer_prefix = re.compile(r"\.(\d+)\.")


def _load_state_dict_into_model(
    model_to_load, state_dict, start_prefix, convert_from_hf
):
    # torch casts dtypes in load_state_dict, but paddle strictly checks dtypes
    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)

    error_msgs = []

    if len(start_prefix) > 0:
        for key in list(state_dict.keys()):
            if key.startswith(start_prefix):
                state_dict[key.replace(start_prefix, "")] = state_dict.pop(key)

    # TODO: add return status to state_dict
    with warnings.catch_warnings(record=True) as w:
        warnings.resetwarnings()
        # paddlenlp tracks missing_keys itself, so ignore "not found" warnings.
        warnings.filterwarnings(
            "ignore", message=r".*is not found in the provided dict.*"
        )
        warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
        if convert_from_hf:
            try:
                model_to_load.set_hf_state_dict(state_dict)
            except NotImplementedError:
                pass
        model_to_load.set_state_dict(state_dict)
        error_msgs.extend([str(x.message) for x in w])

    del state_dict

    return error_msgs

def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
    # convert the dtype of the state dict
    def is_0d_or_1d(tensor):
        return len(tensor.shape) == 0 or list(tensor.shape) == [1]

    for key, value in model_to_load.state_dict().items():
        if key in list(state_dict.keys()):
            if isinstance(state_dict[key], np.ndarray):
                raise ValueError(
                    "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, please convert numpy.ndarray to paddle.Tensor"
                )
            # confirm parameter cast is executed on the same device as model
            # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
            if (
                state_dict[key].is_floating_point()
                and state_dict[key].dtype != value.dtype
            ):
                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)

            # unify 0d and 1d tensors
            if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                if list(value.shape) != list(state_dict[key].shape):
                    state_dict[key] = paddle.reshape(state_dict.pop(key), value.shape)

def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    dtype=None,
    is_safetensors=False,
    keep_in_fp32_modules=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
    params back to the normal device, but only for `loaded_state_dict_keys`.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`
    """
    from paddle.common_ops_import import convert_np_dtype_to_dtype_

    dtype = convert_np_dtype_to_dtype_(dtype)

    error_msgs = []
    model_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            continue
        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]
        if param.place != paddle.framework._current_expected_place():
            param = param._copy_to(paddle.framework._current_expected_place(), False)

        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and paddle.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and (dtype == paddle.float16 or dtype == paddle.bfloat16)
            ):
                param = param.astype(dtype=paddle.float32)
            else:
                param = param.astype(dtype=dtype)

        if dtype is None:
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break
            if old_param is not None:
                param = param.astype(dtype=old_param.dtype)

        with paddle.no_grad():
            model_state_dict[param_name].get_tensor()._share_data_with(
                param.value().get_tensor()
            )
            param.value().get_tensor()._clear()
    return error_msgs

class PretrainedModel(
    Layer, GenerationMixin, ConversionMixin, metaclass=InitTrackerMeta
):
    """
    The base class for all pretrained models. It mainly provides common methods
    for loading (construction and loading) and saving pretrained models. Loading
    and saving also rely on the following class attributes, which should be overridden
    by derived classes accordingly:

    - **model_config_file** (str): Represents the file name of the model configuration
      for configuration saving and loading in the local file system. The value is
      `model_config.json`.
    - **resource_files_names** (dict): Names of local files where the model resources
      can be saved and loaded locally. Currently, resources only include the model state,
      thus the dict only includes `'model_state'` as a key, with the corresponding
      value `'model_state.pdparams'` for model weights saving and loading.
    - **pretrained_init_configuration** (dict): Provides the model configurations
      of built-in pretrained models (as opposed to models in the local file system).
      It has pretrained model names as keys (such as `bert-base-uncased`), and
      the values are dicts preserving the corresponding configuration for model initialization.
    - **pretrained_resource_files_map** (dict): Provides resource URLs of built-in
      pretrained models (as opposed to models in the local file system).
      It has the same keys as `resource_files_names` (that is, "model_state"),
      and the corresponding value is a dict mapping a specific model name to its model weights URL
      (such as "bert-base-uncased" ->
      "https://bj.bcebos.com/paddlenlp/models/transformers/bert-base-uncased.pdparams").
    - **base_model_prefix** (str): Represents the attribute associated with the
      base model in derived classes of the same architecture that add layers on
      top of the base model. Note: A base model class is a pretrained model class
      decorated by `register_base_model`, such as `BertModel`; a derived model
      class is a pretrained model class adding layers on top of the base model,
      and it has a base model as an attribute, such as `BertForSequenceClassification`.

    Methods common to models for text generation are defined in `GenerationMixin`
    and are also inherited here.

    Besides, metaclass `InitTrackerMeta` is used to create `PretrainedModel`,
    by which subclasses can track arguments for initialization automatically.
    """

    # Deprecated(wj-Mcat): after 2.6.* version
    # save the old-school `LEGACY_CONFIG_NAME`, and will be changed to `CONFIG_NAME` after 2.6.* version
    model_config_file = LEGACY_CONFIG_NAME
    pretrained_init_configuration = {}
    # TODO: more flexible resource handle, namedtuple with fields as:
    # resource_name, saved_file, handle_name_for_load(None for used as __init__
    # arguments), handle_name_for_save
    resource_files_names = {"model_state": PADDLE_WEIGHTS_NAME}
    pretrained_resource_files_map = {}
    base_model_prefix = ""
    main_input_name = "input_ids"
    config_class = None
    _keep_in_fp32_modules = None

    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
    # keys we find (keys inside the model but not in the checkpoint) to avoid unnecessary warnings.
    _keys_to_ignore_on_load_missing = None
    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
    # unexpected keys we find (keys inside the checkpoint but not the model) to avoid unnecessary
    # warnings.
    _keys_to_ignore_on_load_unexpected = None
    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
    # trained, but which are either deterministic or tied variables)
    _keys_to_ignore_on_save = None
    _tied_weights_keys = None

    def __init__(self, *args, **kwargs):
        super(PretrainedModel, self).__init__()

        if not self.constructed_from_pretrained_config():
            return

        # extract config from args
        config = None
        for arg in args:
            if isinstance(arg, PretrainedConfig):
                config = arg
                break
        if config is not None:
            self.config: PretrainedConfig = config
            self.model_config_file = CONFIG_NAME
            self.generation_config = (
                GenerationConfig.from_model_config(self.config)
                if self.can_generate()
                else None
            )
            return

        # extract config from kwargs
        if "config" not in kwargs:
            raise ValueError(
                "PretrainedConfig instance not found in the arguments; you can pass it positionally or via the `config` keyword argument"
            )

        config = kwargs["config"]
        if not isinstance(config, PretrainedConfig):
            raise TypeError(
                "the `config` parameter should be an instance of PretrainedConfig"
            )

        self.config: PretrainedConfig = kwargs["config"]
        self.generation_config = (
            GenerationConfig.from_model_config(self.config)
            if self.can_generate()
            else None
        )
        self.model_config_file = CONFIG_NAME
        self.warnings_issued = {}

    def _post_init(self, original_init, *args, **kwargs):
        """
        It is hooked after `__init__` to add a dict including the arguments of
        `__init__` as an attribute named `config` of the pretrained model instance.
        """
        if not self.constructed_from_pretrained_config():
            init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
            self.config = init_dict

        # only execute when it's the base method
        if (
            original_init.__module__ != "paddlenlp.transformers.model_utils"
            and self.__class__.init_weights is PretrainedModel.init_weights
        ):
            self.init_weights()

        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters.
        # However, `self._init_weights` will re-initialize the parameters without
        # synchronizing the shared parameters. If the following step does not load a checkpoint,
        # the shared parameters will be different.
        if isinstance(self, PipelineLayer):
            self._synchronize_shared_weights()

    def _init_weights(self, layer):
        """
        Initialize the weights. This method should be overridden by derived classes.
        """
        pass

    def _initialize_weights(self, layer):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(layer, "_is_initialized", False):
            return
        self._init_weights(layer)
        layer._is_initialized = True

    def init_weights(self):
        """
        If needed, prunes and maybe initializes weights. If using a custom `PretrainedModel`, you need to implement
        any initialization logic in `_init_weights`.
        """
        # call pure
        if _init_weights:
            # Initialize weights
            self.apply(self._initialize_weights)

            # Tie weights should be skipped when not initializing all weights
            # since from_pretrained(...) calls tie weights anyways
            # TODO(wj-Mcat): enable all tie-weights later
            # self.tie_weights()

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        dtype = kwargs.pop("dtype", None)
        if dtype is None:
            if config.dtype is not None:
                dtype = config.dtype
            else:
                dtype = paddle.get_default_dtype()

        with dtype_guard(dtype):
            model = cls(config, **kwargs)

        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        return cls._from_config(config, **kwargs)

    @classmethod
    def set_inference_config(cls, config, predictor_args, **kwargs):
        """
        All inference configuration can be set here.

        Args:
            config : PretrainedConfig
                The config of the model.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
        tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)

        if predictor_args.mode == "dynamic" or predictor_args.speculate_method in [
            "eagle",
            "mtp",
        ]:
            config.tensor_parallel_degree = tensor_parallel_degree
            config.tensor_parallel_rank = tensor_parallel_rank
            config.model_name_or_path = predictor_args.model_name_or_path
            config.quant_type = predictor_args.quant_type
            config.cachekv_int8_type = predictor_args.cachekv_int8_type
            config.use_fake_parameter = predictor_args.use_fake_parameter
            config.single_card_ptq = not predictor_args.use_fake_parameter
            config.append_attn = predictor_args.append_attn
            config.decode_strategy = predictor_args.decode_strategy
            config.mla_use_matrix_absorption = predictor_args.mla_use_matrix_absorption
            config.weightonly_group_size = predictor_args.weightonly_group_size
            config.weight_block_size = predictor_args.weight_block_size
            config.moe_quant_type = predictor_args.moe_quant_type
            if predictor_args.block_attn:
                config.block_size = predictor_args.block_size
                config.max_seq_len = predictor_args.total_max_length
            if predictor_args.speculate_method is not None:
                config.speculate_method = predictor_args.speculate_method
                config.speculate_max_draft_token_num = (
                    predictor_args.speculate_max_draft_token_num
                )
                config.speculate_verify_window = predictor_args.speculate_verify_window
                config.speculate_max_candidate_len = (
                    predictor_args.speculate_max_candidate_len
                )
                if predictor_args.speculate_method == "inference_with_reference":
                    config.speculate_max_ngram_size = (
                        predictor_args.speculate_max_ngram_size
                    )
        if predictor_args.speculate_method is not None:
            if config.get("speculate_model_type", "None") not in ["eagle", "mtp"]:
                config.decode_strategy = "speculate_decoding"
        config.return_full_hidden_states = predictor_args.return_full_hidden_states

    @classmethod
    def confirm_inference_model(cls, predictor_args, **kwargs):
        """
        Confirm whether the inference model needs to be switched to the AVX inference model.

        Args:
            model : PretrainedModel
                The model for inference.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        return cls

    @property
    def base_model(self):
        """
        PretrainedModel: The body of the same model architecture. It is the base
            model itself for a base model, or the base model attribute for a derived
            model.
        """
        return getattr(self, self.base_model_prefix, self)

    @property
    def model_name_list(self):
        """
        list: Contains all supported built-in pretrained model names of the
            current PretrainedModel class.
        """
        # TODO: return all model names
        return list(self.pretrained_init_configuration.keys())

    def can_generate(self) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
            return False
        return True

    def recompute_enable(self):
        r"""
        Enable recompute.
        All layers with the `enable_recompute` attribute will be set to `True`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is False or layer.enable_recompute == 0
            ):
                layer.enable_recompute = True

        self.apply(fn)

    def recompute_disable(self):
        r"""
        Disable recompute.
        All layers with the `enable_recompute` attribute will be set to `False`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is True or layer.enable_recompute == 1
            ):
                layer.enable_recompute = False

        self.apply(fn)

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        """
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if output_embeddings is not None and input_embeddings is not None:
                if input_embeddings.weight.shape != output_embeddings.weight.shape:
                    logging.warning(
                        f"The shape of input embeddings is {input_embeddings.weight.shape} and the shape of output embeddings is {output_embeddings.weight.shape}. "
                        "This is only expected if you are calling the `resize_token_embeddings` method"
                    )
                output_embeddings.weight = input_embeddings.weight
                if getattr(output_embeddings, "bias", None) is not None:
                    # need to pad
                    if (
                        output_embeddings.weight.shape[0]
                        > output_embeddings.bias.shape[0]
                    ):
                        old_bias = output_embeddings.bias
                        pad_length = (
                            output_embeddings.weight.shape[0] - old_bias.shape[0]
                        )
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        new_bias = paddle.concat(
                            [
                                old_bias,
                                paddle.zeros(
                                    [pad_length], dtype=output_embeddings.bias.dtype
                                ),
                            ]
                        )
                        output_embeddings.bias.set_value(new_bias)
                    # need to trim
                    elif (
                        output_embeddings.weight.shape[0]
                        < output_embeddings.bias.shape[0]
                    ):
                        new_bias = output_embeddings.bias[
                            : output_embeddings.weight.shape[0]
                        ]
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        output_embeddings.bias.set_value(new_bias)

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """resize position embeddings; this method should be overridden by downstream models

        Args:
            new_num_position_embeddings (int): the new position size

        Raises:
            NotImplementedError: when called and not implemented
        """
        raise NotImplementedError(
            f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
            f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}.py`"
        )

    @classmethod
    def constructed_from_pretrained_config(cls, init_func=None) -> bool:
        """check if the model is constructed from `PretrainedConfig`

        Returns:
            bool: if the model is constructed from `PretrainedConfig`
        """
        return cls.config_class is not None and issubclass(
            cls.config_class, PretrainedConfig
        )

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Resizes the input token embeddings matrix of the model according to new_num_tokens.

        Args:
            new_num_tokens (Optional[int]):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just
                returns a pointer to the input tokens embedding module of the model without doing anything.

        Returns:
            paddle.nn.Embedding: The input tokens Embeddings Module of the model.
        """
        old_embeddings: nn.Embedding = self.get_input_embeddings()
        if not new_num_tokens or new_num_tokens == old_embeddings.weight.shape[0]:
            return old_embeddings

        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)

        # Update vocab_size
        self.base_model.config["vocab_size"] = new_num_tokens
        self.vocab_size = new_num_tokens

        # update init_config
        self._update_init_config(self.init_config, "vocab_size", new_num_tokens)

        # Tie the weights between the input embeddings and the output embeddings if needed.
        self.tie_weights()

        return new_embeddings
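
    # Usage sketch (illustrative; assumes a tokenizer whose vocabulary has just
    # grown, e.g. after `tokenizer.add_tokens(...)`):
    #
    #     model.resize_token_embeddings(len(tokenizer))
    #     # the input (and, if tied, output) embeddings now have len(tokenizer)
    #     # rows, and `vocab_size` in the config is updated to match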

    def _update_init_config(self, init_config: dict, key: str, value: Any):
        """update init_config by <key, value> pair

        Args:
            init_config (dict): the init_config instance
            key (str): the key field
            value (Any): the new value of the instance
        """
        if key in init_config:
            init_config[key] = value
            return

        for arg in init_config.get("init_args", []):
            if not isinstance(arg, PretrainedModel):
                continue
            self._update_init_config(arg.init_config, key, value)

    def _get_resized_embeddings(
        self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (nn.Embedding):
                Old embeddings to be resized.
            new_num_tokens (Optional[int]):
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end.

        Returns:
            paddle.nn.Embedding: The resized Embedding Module or the old Embedding Module if new_num_tokens is None.
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        if not isinstance(old_embeddings, nn.Embedding):
            raise TypeError(
                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
                " should either use a different resize function or make sure that old_embeddings are an instance of"
                f" {nn.Embedding}."
            )

        # Build new embeddings
        new_embeddings = nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            padding_idx=old_embeddings._padding_idx,
            sparse=old_embeddings._sparse,
        )

        # make sure that new_embeddings' dtype is the same as the old embeddings' dtype
        if new_embeddings.weight.dtype != old_embeddings.weight.dtype:
            new_embeddings.to(dtype=old_embeddings.weight.dtype)

        # number of tokens to copy
        n = min(old_num_tokens, new_num_tokens)
        with paddle.no_grad():
            new_embeddings.weight[:n, :] = old_embeddings.weight[:n, :]

        return new_embeddings

    def __setattr__(self, name, value):
        value = adapt_stale_fwd_patch(self, name, value)
        return super(PretrainedModel, self).__setattr__(name, value)

    @classmethod
    def _resolve_model_file_path(
        cls: Type[PretrainedModel],
        pretrained_model_name_or_path: str,
        from_hf_hub: bool = False,
        from_aistudio: bool = False,
        cache_dir: str | None = None,
        subfolder: Optional[str] = "",
        config: PretrainedConfig = None,
        convert_from_torch: bool = False,
        use_safetensors: bool | None = None,
        variant=None,
    ) -> tuple:
        """resolve the model target file path from `pretrained_model_name_or_path` and `cache_dir`

        1. when it is a file path:
            return the weight file
        2. when it is a model name:
            2.1 check default `MODEL_HOME` + `model-name` + model_state.pdparams
            2.2 get the url from `pretrained_resource_files_map`, and set it to `pretrained_model_name_or_path`
        3. when it is a local dir:
            check whether the file <local_dir + weight_file> exists

        Args:
            cls (Type[PretrainedModel]): the inherited PretrainedModel class
            pretrained_model_name_or_path (str): the model name, url, or local directory
            cache_dir (Optional[str], optional): cache_dir is used when name_or_path is a model name or url. Defaults to None.
            convert_from_torch (bool, optional): whether to support converting a pytorch model to a paddle model

        Returns:
            tuple: (resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded)
        """
        is_sharded = False
        sharded_metadata = None

        if pretrained_model_name_or_path is not None:
            # the following code uses os.path.join a lot, hence set subfolder to an empty str if None
            if subfolder is None:
                subfolder = ""
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            is_local = os.path.isdir(pretrained_model_name_or_path)

            def get_file_path(
                pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, variant
            ):
                return os.path.join(
                    pretrained_model_name_or_path,
                    subfolder,
                    _add_variant(SAFE_WEIGHTS_NAME, variant),
                )

            # pretrained_model_name_or_path is a file
            if os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
                is_local = True
            # pretrained_model_name_or_path is a dir
            elif is_local:
                if use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                else:
                    raise EnvironmentError(
                        f"Error: no file named {_add_variant(PADDLE_WEIGHTS_NAME, variant)} found in directory"
                        f" {pretrained_model_name_or_path}."
                    )
            elif pretrained_model_name_or_path in cls.pretrained_init_configuration:
                # fetch the weight url from the `pretrained_resource_files_map`
                resource_file_url = cls.pretrained_resource_files_map["model_state"][
                    pretrained_model_name_or_path
                ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    [resource_file_url],
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
            else:
                if use_safetensors is True:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                    ]
                elif use_safetensors is None:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                else:
                    filenames = [
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    filenames,
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError(
                        f"Error: none of the files {filenames} were found in repo {pretrained_model_name_or_path}."
                    )
                elif "pytorch_model.bin" in str(resolved_archive_file):
                    if not from_hf_hub and not convert_from_torch:
                        raise ValueError(
                            f"Downloaded PyTorch weights at"
                            f" {resolved_archive_file}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
            if is_local:
                logging.info(f"Loading weights file {archive_file}")
                resolved_archive_file = archive_file
            else:
                logging.info(
                    f"Loading weights file from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None
        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        resolved_sharded_files = None
        if str(resolved_archive_file).endswith(".json"):
            is_sharded = True
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                from_aistudio=from_aistudio,
                from_hf_hub=from_hf_hub,
                cache_dir=cache_dir,
                subfolder=subfolder,
            )

        return (
            resolved_archive_file,
            resolved_sharded_files,
            sharded_metadata,
            is_sharded,
        )
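
    # Illustrative result (paths are placeholders): for a local directory that
    # holds a sharded safetensors checkpoint, the method returns roughly
    # ("<dir>/model.safetensors.index.json", [shard file paths], <index metadata>, True).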

    @classmethod
    def _load_pretrained_model(
        cls,
        model: PretrainedModel,
        state_dict: Dict[str, Tensor],
        loaded_keys: List[str],
        resolved_archive_file: Union[str, List] = [],
        pretrained_model_name_or_path=None,
        config=None,
        ignore_mismatched_sizes=False,
        low_cpu_mem_usage=False,
        dtype=None,
        keep_in_fp32_modules=None,
        quantization_linear_list=None,
        sharded_metadata=None,
        convert_from_hf=False,
    ) -> Tuple[List[str]]:
        """load the state_dict into the model, and do the following things:

            * check missing/unexpected keys, resolving base-model prefixes
            * load the weights into the model

        Args:
            model (PretrainedModel): the pretrained model instance
            state_dict (Dict[str, Tensor]): the model state dict data
            loaded_keys (List[str]): the keys present in the checkpoint
            ignore_mismatched_sizes (bool, optional): whether to ignore errors when tensor sizes mismatch. Defaults to False.
            dtype (_type_, optional): the dtype of the model state dict. Defaults to None.

        Returns:
            Tuple[List[str]]: lists of keys describing the load result (e.g. missing/unexpected keys)
        """
        is_safetensors = False

        if convert_from_hf:
            try:
                model_state_dict = model.get_hf_state_dict()
            except NotImplementedError:
                model_state_dict = model.state_dict()
        else:
            model_state_dict = model.state_dict()

        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [
                s for s in expected_keys if not s.startswith(_prefix)
            ]
            expected_keys = [
                s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys
            ]
            if quantization_linear_list is not None:
                quantization_linear_list = [
                    s[len(_prefix) :] if s.startswith(_prefix) else s
                    for s in quantization_linear_list
                ]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]
            if quantization_linear_list is not None:
                quantization_linear_list = [
                    ".".join([prefix, s]) for s in quantization_linear_list
                ]
        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Optimization: skip unused shard files for super large models
        if sharded_metadata is not None:
            assert isinstance(resolved_archive_file, list)
            new_archive_file = []
            skip_archive_file = []
            expected_keys_set = set(expected_keys)
            for file in resolved_archive_file:
                filename = os.path.split(file)[-1]
                if not expected_keys_set.isdisjoint(
                    set(sharded_metadata["file_map"][filename])
                ):
                    new_archive_file.append(file)
                else:
                    skip_archive_file.append(filename)

            resolved_archive_file = new_archive_file
            if len(skip_archive_file) > 0:
                logging.info(
                    f"Skipping shard files that contain no expected keys: {skip_archive_file}"
                )
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [
                    k for k in unexpected_keys if re.search(pat, k) is None
                ]

        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.named_parameters():
                if any(
                    module_to_keep_in_fp32 in name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                ):
                    if param.dtype != paddle.float32:
                        param_fp32 = param.cast(dtype=paddle.float32)
                        param_fp32_tensor = param_fp32.value().get_tensor()
                        param_tensor = param.value().get_tensor()
                        param_tensor._share_data_with(param_fp32_tensor)

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if (
            len(cls.base_model_prefix) > 0
            and not hasattr(model, cls.base_model_prefix)
            and has_prefix_module
        ):
            start_prefix = cls.base_model_prefix + "."
        if (
            len(cls.base_model_prefix) > 0
            and hasattr(model, cls.base_model_prefix)
            and not has_prefix_module
        ):
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(
                key in expected_keys_not_prefixed
                and key not in base_model_expected_keys
                for key in loaded_keys
            ):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )
        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])
                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape
                        != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (
                                checkpoint_key,
                                state_dict[checkpoint_key].shape,
                                model_state_dict[model_key].shape,
                            )
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys
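        # Each returned entry is (checkpoint_key, checkpoint_shape, model_shape); a
        # hypothetical example is ("classifier.weight", [768, 2], [768, 3]) when a
        # 2-class head checkpoint meets a 3-class model. The offending key is removed
        # from `state_dict`, so the model keeps its freshly initialized weight instead.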
        def _fuse_or_split_keys(
            state_dict,
            config,
            loaded_keys,
            pre_tensor_parallel_split=False,
            resume_state_dict=None,
        ):
            if resume_state_dict is not None:
                state_dict.update(resume_state_dict)

            before_fuse_keys = list(state_dict.keys())

            if pre_tensor_parallel_split:
                tp_actions = cls.get_tensor_parallel_convert_actions(
                    config, loaded_keys, ignore_error=True
                )
            else:
                tp_actions = None
            state_dict, resume_state_dict = cls.convert_fuse_and_split(
                config, state_dict, tp_actions
            )

            after_fuse_keys = list(state_dict.keys())

            fused_keys = list(set(before_fuse_keys) - set(after_fuse_keys))
            new_keys = list(set(after_fuse_keys) - set(before_fuse_keys))

            return state_dict, resume_state_dict, fused_keys, new_keys
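        # Illustration (hypothetical key names): fusing may combine "q_proj.weight",
        # "k_proj.weight" and "v_proj.weight" from the checkpoint into one "qkv_proj.weight"
        # expected by the model; `fused_keys` are the consumed source keys and `new_keys`
        # the produced ones, which is why callers subtract them from `unexpected_keys`
        # and `missing_keys` respectively.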
        if state_dict is not None:
            # the full state dict has already been loaded, so there is no resume state dict
            state_dict, _, fused_keys, new_keys = _fuse_or_split_keys(
                state_dict,
                config,
                loaded_keys,
                pre_tensor_parallel_split=(
                    config is not None and config.tensor_parallel_degree > 1
                ),
            )
            missing_keys = list(set(missing_keys) - set(new_keys))
            unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(
                model_to_load,
                state_dict,
                start_prefix,
                convert_from_hf=convert_from_hf,
            )
        else:
            # Sharded checkpoint, or a whole checkpoint with low_cpu_mem_usage==True.
            # This should always be a list but, just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            resume_state_dict = {}
            for shard_file in resolved_archive_file:
                pre_tensor_parallel_split = False
                if (
                    shard_file.endswith(".safetensors")
                    and config.tensor_parallel_degree > 1
                    and "tp" not in os.path.split(shard_file)[-1]
                ):
                    pre_tensor_parallel_split = True
                    assert loaded_keys is not None, "loaded_keys should not be None."
                    tp_actions = cls.get_tensor_parallel_convert_actions(
                        config, loaded_keys, ignore_error=True
                    )
                # Here we use expected_keys to optimize weight loading for pipeline models. Only works for safetensors.
                filter_dict_keys = set(expected_keys)

                fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=True
                )
                split_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=False
                )
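                # The two loops below widen `filter_dict_keys` so that every raw
                # checkpoint key needed to assemble an expected fused (or split)
                # parameter is read from the shard, even when that raw key itself
                # is not an expected key.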
                for k in list(fuse_actions.keys()):
                    need_add_except_key = k[-1] in expected_keys
                    if need_add_except_key:
                        filter_dict_keys |= set(k[:-1])
                    # remove the pre_tensor_parallel_split functions from tp_actions
                    if pre_tensor_parallel_split:
                        for item in k[:-1]:
                            if item in tp_actions:
                                tp_actions.pop(item, None)

                for k in list(split_actions.keys()):
                    need_add_except_key = False
                    for item in k[:-1]:
                        if item in expected_keys:
                            need_add_except_key = True
                            break
                    if need_add_except_key:
                        filter_dict_keys.add(k[-1])
                    # remove the pre_tensor_parallel_split function from tp_actions
                    if pre_tensor_parallel_split:
                        if k[-1] in tp_actions:
                            tp_actions.pop(k[-1], None)

                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None

                state_dict = load_state_dict(
                    shard_file,
                    tp_actions if pre_tensor_parallel_split else None,
                    filter_dict_keys,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )
                # convert for fusing or splitting weights
                state_dict, resume_state_dict, fused_keys, new_keys = (
                    _fuse_or_split_keys(
                        state_dict,
                        config,
                        loaded_keys,
                        pre_tensor_parallel_split=pre_tensor_parallel_split,
                        resume_state_dict=resume_state_dict,
                    )
                )
                missing_keys = list(set(missing_keys) - set(new_keys))
                unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )

                if (
                    config.tensor_parallel_degree > 1
                    and ".tp" not in shard_file
                    and not pre_tensor_parallel_split
                ):
                    logging.info("Converting state_dict to Tensor Parallel Format")
                    # ignore errors for multiple shards, since each shard contains only part of the data
                    state_dict = cls.convert_tensor_parallel(
                        None,
                        config,
                        state_dict=state_dict,
                        ignore_error=len(resolved_archive_file) > 1,
                    )
                    logging.info("Converted state_dict to Tensor Parallel Format")

                if low_cpu_mem_usage:
                    new_error_msgs = _load_state_dict_into_meta_model(
                        model_to_load,
                        state_dict,
                        loaded_keys,
                        start_prefix,
                        expected_keys,
                        dtype=dtype,
                        is_safetensors=is_safetensors,
                        keep_in_fp32_modules=keep_in_fp32_modules,
                    )
                    error_msgs += new_error_msgs
                else:
                    error_msgs += _load_state_dict_into_model(
                        model_to_load,
                        state_dict,
                        start_prefix,
                        convert_from_hf=convert_from_hf,
                    )

                # force memory release
                del state_dict
                gc.collect()
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if " but the expected shape is" in error_msg:
                error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            raise RuntimeError(
                f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
            )

        if len(unexpected_keys) > 0:
            logging.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {sorted(unexpected_keys)}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logging.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logging.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )

        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys
    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, *args, convert_from_hf=False, **kwargs
    ):
        """
        Creates an instance of `PretrainedModel`. Model weights are loaded
        by specifying the name of a built-in pretrained model, a pretrained model from HF Hub,
        a community-contributed model, or a local file directory path.

        Args:
            pretrained_model_name_or_path (str): Name of the pretrained model or dir path
                to load from. The string can be:
                - Name of a built-in pretrained model
                - Name of a pretrained model from HF Hub
                - Name of a community-contributed pretrained model.
                - Local directory path which contains the model weights file
                  ("model_state.pdparams") and model config file ("model_config.json").
            convert_from_hf (bool, optional): Whether to convert weights stored in the
                Hugging Face format. Defaults to `False`.
            from_hf_hub (bool): Load model from Hugging Face Hub. Defaults to `False`.
            subfolder (str, optional): An optional value corresponding to a folder inside the repo.
                Only works when loading from Hugging Face Hub.
            *args (tuple): Positional arguments for the model `__init__`. If provided,
                use these as positional argument values for model initialization.
            **kwargs (dict): Keyword arguments for the model `__init__`. If provided,
                use these to update pre-defined keyword argument values for model
                initialization. If the keyword is in the `__init__` argument names of
                the base model, update argument values of the base model; else update
                argument values of the derived model.
            load_state_as_np (bool, optional): Deprecated; a warning is raised if provided.
                Whether the weights read in are placed on CPU or GPU while the model is on
                the default device. If `True`, load the model weights as `numpy.ndarray`
                on CPU. Otherwise, weights are loaded as tensors on the default device.
                Note that on GPU the latter creates extra temporary tensors in addition
                to the model weights, which doubles the memory usage. It is therefore
                suggested to use `True` for big models on GPU. Defaults to `False`.

        Returns:
            PretrainedModel: An instance of `PretrainedModel`.

        Example:
            .. code-block::

                from paddlenlp.transformers import BertForSequenceClassification

                # Name of built-in pretrained model
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

                # Name of community-contributed pretrained model
                model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned', num_labels=3)

                # Load from local directory path
                model = BertForSequenceClassification.from_pretrained('./my_bert/')
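                # Hypothetical combination (illustrative): the `dtype` and
                # `low_cpu_mem_usage` kwargs handled by this method can be used
                # together to cut peak host memory when loading big checkpoints
                model = BertForSequenceClassification.from_pretrained(
                    './my_bert/', dtype='float16', low_cpu_mem_usage=True
                )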
  1514. """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        dtype = kwargs.pop("dtype", None)
        from_hf_hub = kwargs.pop("from_hf_hub", False)
        from_aistudio = kwargs.pop("from_aistudio", False)
        subfolder = kwargs.pop("subfolder", None)
        if subfolder is None:
            subfolder = ""
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop(
            "use_safetensors", None if is_dep_available("safetensors") else False
        )
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
        convert_from_torch = kwargs.pop("convert_from_torch", None)
        load_state_as_np = kwargs.pop("load_state_as_np", None)
        if load_state_as_np is not None:
            logging.warning("`load_state_as_np` is deprecated, please delete it!")

        model_kwargs = kwargs

        if convert_from_torch is None and os.environ.get("from_modelscope", False):
            logging.warning(
                "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior"
                " of considering torch weights, you can set `convert_from_torch=False`. By default,"
                " `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # from_hf_hub enables convert_from_torch by default
        if from_hf_hub and convert_from_torch is None:
            logging.warning(
                "If you are attempting to load weights from Hugging Face Hub and want to disable the default behavior"
                " of considering torch weights, you can set `convert_from_torch=False`. By default,"
                " `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # convert_from_torch defaults to False
        if convert_from_torch is None:
            convert_from_torch = False
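        # Net effect of the resolution above: an explicit `convert_from_torch` argument
        # always wins; otherwise it defaults to True for ModelScope and HF Hub sources
        # and to False everywhere else.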
        # 1. get the PretrainedConfig to init the model
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                subfolder=subfolder,
                return_unused_kwargs=True,
                **kwargs,
            )
        if "from_aistudio" in model_kwargs:
            model_kwargs.pop("from_aistudio")

        if dtype is None:
            dtype = config.dtype
        config.dtype = dtype

        init_contexts = []
        if dtype:
            init_contexts.append(dtype_guard(dtype))

        # Keep in fp32 modules
        keep_in_fp32_modules = None
        use_keep_in_fp32_modules = False

        # resolve the model weight file(s)
        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = (
            cls._resolve_model_file_path(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                subfolder=subfolder,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                config=config,
                convert_from_torch=False,
                use_safetensors=use_safetensors,
                variant=variant,
            )
        )

        init_args = config["init_args"] or ()
        with ContextManagers(init_contexts):
            model = cls(config, *init_args, **model_kwargs)
        if convert_from_torch and state_dict is None:
            if (
                resolved_archive_file.endswith(PYTORCH_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PYTORCH_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
            ):
                # try to get the name-mapping info
                convert_dir = os.path.dirname(resolved_archive_file)
                logging.info(
                    f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
                    f"paddle weight file<{convert_dir}> ..."
                )
                state_dict = cls.convert(
                    resolved_archive_file,
                    config,
                    # cache_dir=os.path.join(cache_dir, pretrained_model_name_or_path, subfolder),
                    cache_dir=convert_dir,
                )
            elif (
                resolved_archive_file.endswith(PADDLE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PADDLE_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(".pdparams")
            ):
                logging.info(f"file: {resolved_archive_file} is a paddle weight file.")
            else:
                raise ValueError(
                    f"Unexpected file: {resolved_archive_file} for weight conversion."
                )
        # load pt weights early so that we know which dtype to init the model under
        if not is_sharded and state_dict is None:
            # 4. loading non-sharded ckpt from the state dict
            if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model_state.pdparams"
            ):
                state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
            elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model.safetensors"
            ):
                raise NotImplementedError
            else:
                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None
                state_dict = load_state_dict(
                    resolved_archive_file,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )

            logging.info("Loaded weights file from disk, setting weights to model.")
        # Check whether `_keep_in_fp32_modules` is set
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            dtype == "float16" or dtype == "bfloat16"
        )

        if state_dict is not None:
            loaded_state_dict_keys = [k for k in state_dict.keys()]
            # only loading paddle.Tensor into the model is supported
            for k in list(state_dict.keys()):
                if not isinstance(state_dict[k], paddle.Tensor):
                    with device_guard():
                        state_dict[k] = paddle.Tensor.__call__(
                            state_dict.pop(k), zero_copy=True
                        )
        else:
            if is_sharded:
                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
            else:
                loaded_state_dict_keys = [k for k in state_dict.keys()]
            if low_cpu_mem_usage:  # or use_keep_in_fp32_modules:
                state_dict = None
            # only loading paddle.Tensor into the model is supported
            if state_dict is not None:
                for k in list(state_dict.keys()):
                    if not isinstance(state_dict[k], paddle.Tensor):
                        with device_guard():
                            state_dict[k] = paddle.Tensor.__call__(
                                state_dict.pop(k), zero_copy=True
                            )

        if use_keep_in_fp32_modules:
            # low_cpu_mem_usage = True
            keep_in_fp32_modules = model._keep_in_fp32_modules
        else:
            keep_in_fp32_modules = []

        quantization_linear_list = None
        model, missing_keys, unexpected_keys, mismatched_keys = (
            cls._load_pretrained_model(
                model=model,
                state_dict=state_dict,
                loaded_keys=loaded_state_dict_keys,
                resolved_archive_file=(
                    resolved_sharded_files if is_sharded else resolved_archive_file
                ),
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                config=config,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                low_cpu_mem_usage=low_cpu_mem_usage,
                dtype=dtype,
                keep_in_fp32_modules=keep_in_fp32_modules,
                quantization_linear_list=quantization_linear_list,
                sharded_metadata=sharded_metadata if is_sharded else None,
                convert_from_hf=convert_from_hf,
            )
        )

        # load generation_config.json
        if model.can_generate() and pretrained_model_name_or_path is not None:
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    from_hf_hub=from_hf_hub,
                    from_aistudio=from_aistudio,
                    subfolder=subfolder,
                    **kwargs,
                )
            except Exception:
                logging.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters.
        #    However, when the state dict contains only a single copy of the shared parameters,
        #    the shared parameters will end up different from the original shared parameters.
        if isinstance(model, PipelineLayer):
            model._synchronize_shared_weights()

        if paddle.in_dynamic_mode():
            return model

        return model, state_dict
    def merge_auto_dist_configs(self, configs):
        """
        Merge all auto dist configs into one config.
        `configs` is a list of configs; each config is a dict representing one model's
        auto_dist_config, for example:
        [
            {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on the global mesh.
                }
            }, {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on the global mesh.
                }
            }, ...
        ]
        """
        assert isinstance(configs, (dict, list))
        if isinstance(configs, dict):
            return configs
        final_config = {
            "mp_config": None,
            "sp_config": None,
            "pp_config": None,
        }
        for config in configs:
            if "mp_config" in config and config["mp_config"] is not None:
                if final_config["mp_config"] is None:
                    final_config["mp_config"] = config["mp_config"]
                else:
                    for k, v in config["mp_config"]["parallelize_plan"].items():
                        assert (
                            k
                            not in final_config["mp_config"]["parallelize_plan"].keys()
                        ), f"sublayer mp_config should be a subset of the model's but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
                        final_config["mp_config"]["parallelize_plan"][k] = v
            if "sp_config" in config and config["sp_config"] is not None:
                if final_config["sp_config"] is None:
                    final_config["sp_config"] = config["sp_config"]
                else:
                    for k, v in config["sp_config"]["parallelize_plan"].items():
                        assert (
                            k
                            not in final_config["sp_config"]["parallelize_plan"].keys()
                        ), f"sublayer sp_config should be a subset of the model's but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
                        final_config["sp_config"]["parallelize_plan"][k] = v
            if "pp_config" in config and config["pp_config"] is not None:
                if isinstance(config["pp_config"]["split_spec"], str):
                    config["pp_config"]["split_spec"] = [
                        config["pp_config"]["split_spec"]
                    ]
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]
                elif isinstance(config["pp_config"]["split_spec"], (tuple, list)):
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]

        if (
            final_config["pp_config"] is not None
            and len(final_config["pp_config"]["split_spec"]) == 1
        ):
            final_config["pp_config"]["split_spec"] = final_config["pp_config"][
                "split_spec"
            ][0]

        return final_config
    def _generate_auto_dist_config(self, auto_dist_degree):
        merged_config = {
            "sp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        for name, layer in self.named_sublayers(include_self=True):
            if hasattr(layer, "auto_dist_config"):
                if name != "":
                    prefix = name + "."
                else:
                    prefix = ""
                layer_config = layer.auto_dist_config(prefix)
                merged_config = self.merge_auto_dist_configs(
                    [merged_config, layer_config]
                )
                for _, deeper_layer in layer.named_sublayers():
                    if hasattr(deeper_layer, "auto_dist_config"):
                        # mask all `auto_dist_config` methods in deeper layers
                        deeper_layer.auto_dist_config = lambda x: {}

        final_config = {
            "dp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        if (
            "tensor_parallel" in auto_dist_degree
            and auto_dist_degree["tensor_parallel"]
        ):
            assert merged_config["mp_config"] is not None
            final_config["mp_config"] = merged_config["mp_config"]

        if (
            "sequence_parallel" in auto_dist_degree
            and auto_dist_degree["sequence_parallel"]
        ):
            assert merged_config["sp_config"] is not None
            final_config["mp_config"] = merged_config["sp_config"]

        if (
            "pipeline_parallel" in auto_dist_degree
            and auto_dist_degree["pipeline_parallel"]
        ):
            assert merged_config["pp_config"] is not None
            final_config["pp_config"] = merged_config["pp_config"]

        return final_config
    def get_transpose_weight_keys(self):
        raise NotImplementedError

    def get_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError

    def set_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError
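    # Minimal sketch (an assumption for illustration, not part of this module) of how a
    # subclass might implement the HF-interop hooks above; `MyModel` and the key names
    # are hypothetical:
    #
    #     class MyModel(PretrainedModel):
    #         def get_transpose_weight_keys(self):
    #             # nn.Linear weights that HF-format checkpoints store transposed
    #             return ["encoder.dense.weight", "classifier.weight"]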