# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import gc
import os
import re
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import paddle
import paddle.nn as nn
from paddle import Tensor
from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer

try:
    from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc
except ImportError:
    LocalSharedLayerDesc = None
from paddle.nn import Layer

from ......utils import logging
from ......utils.deps import is_dep_available, require_deps
from ...tokenizer.tokenizer_utils import InitTrackerMeta, adapt_stale_fwd_patch
from ..generation import GenerationConfig, GenerationMixin
from ..utils import (
    ASYMMETRY_QUANT_SCALE_MAX,
    ASYMMETRY_QUANT_SCALE_MIN,
    CONFIG_NAME,
    LEGACY_CONFIG_NAME,
    PADDLE_WEIGHTS_INDEX_NAME,
    PADDLE_WEIGHTS_NAME,
    PYTORCH_WEIGHTS_INDEX_NAME,
    PYTORCH_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    SYMMETRY_QUANT_SCALE,
    device_guard,
    resolve_file_path,
)
from .configuration_utils import PretrainedConfig
from .conversion_utils import ConversionMixin
from .utils import (
    ContextManagers,
    fn_args_to_dict,
    get_checkpoint_shard_files,
    paddlenlp_load,
    weight_name_suffix,
)

__all__ = [
    "PretrainedModel",
]


def _add_variant(weights_name: str, variant=None) -> str:
    if variant is not None and len(variant) > 0:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)
    return weights_name
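
# Illustration (editorial sketch, not part of the original file): the variant
# is spliced in just before the final extension.
#   _add_variant("model_state.pdparams", "tp00") -> "model_state.tp00.pdparams"
#   _add_variant("model_state.pdparams", None)   -> "model_state.pdparams"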


@contextmanager
def dtype_guard(dtype="float32"):
    origin_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(origin_dtype)
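
# Usage sketch (illustrative; `MyModel` is a hypothetical subclass): build a
# model under bfloat16 so newly created parameters default to that dtype, then
# restore the previous default on exit.
#   with dtype_guard("bfloat16"):
#       model = MyModel(config)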


_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    old_init_weights = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = old_init_weights
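
# Usage sketch (illustrative): skip the expensive random init when the weights
# will be overwritten by a checkpoint immediately afterwards. Inside the block,
# `init_weights()` sees the module-level `_init_weights` flag as False and
# becomes a no-op.
#   with no_init_weights():
#       model = cls(config)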


def _split_keys_evenly(keys: list, n: int) -> list:
    """Split a list into n lists whose lengths differ by at most one.
    Args:
        keys (list): the list to be split
        n (int): number of splits
    Returns:
        result: list of lists
    """
    total_len = len(keys)
    base_size = total_len // n
    extra = total_len % n

    result = []
    index = 0
    for _ in range(n):
        part_size = base_size + 1 if extra > 0 else base_size
        extra -= 1
        result.append(keys[index : index + part_size])
        index += part_size
    return result
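
# Worked example (editorial sketch): 7 keys over 3 splits -- the remainder goes
# to the leading chunks, so chunk sizes differ by at most one.
#   _split_keys_evenly(["a", "b", "c", "d", "e", "f", "g"], 3)
#   -> [["a", "b", "c"], ["d", "e"], ["f", "g"]]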


def _load_part_state_dict_from_safetensors(
    keys,
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping,
    fliter_dict_keys,
    device,
    quantization_linear_list=None,
    quantization_config=None,
    dtype=None,
    return_numpy=False,
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    import paddle
    from safetensors import safe_open

    if transpose_weight_keys:
        transpose_weight_keys = set(transpose_weight_keys)

    def _is_need_transpose(key):
        if "lora" not in key and convert_from_hf and transpose_weight_keys:
            return key in transpose_weight_keys
        return False

    def _transpose_hf_weight(key, weight):
        if _is_need_transpose(key):
            return weight.transpose([-1, -2])
        return weight

    part_state_dict = {}
    scale_dict = {}
    with safe_open(checkpoint_file, framework="paddle") as f:
        for key in keys:
            # 1. Non-merged checkpoints are loaded without filter keys.
            # 2. Merged checkpoints skip quant scales via `fliter_dict_keys`.
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                continue
            if fliter_dict_keys is not None and key not in fliter_dict_keys:
                continue
            py_safe_slice_ = f.get_slice(key)
            if (
                quantization_linear_list is not None
                and key.split(".weight")[0] in quantization_linear_list
                and not key.endswith("_scale")
            ):
                raise NotImplementedError
            else:
                if key in tensor_parallel_split_mapping:
                    tp_fn = tensor_parallel_split_mapping[key]
                    if _is_need_transpose(key):
                        # The HF weight is transposed below, so the column/row
                        # split orientation has to be flipped first.
                        assert isinstance(tp_fn, partial)
                        is_column = True
                        if "is_column" in tp_fn.keywords:
                            is_column = tp_fn.keywords["is_column"]
                        is_column = not is_column
                        tp_fn = partial(
                            tp_fn.func,
                            *tp_fn.args,
                            **{**tp_fn.keywords, "is_column": is_column},
                        )
                    if len(py_safe_slice_.shape) == 0:
                        weight = tp_fn(py_safe_slice_[:])
                    else:
                        weight = tp_fn(py_safe_slice_)
                else:
                    weight = py_safe_slice_[:]
                if not return_numpy and device == "expected":
                    weight = weight._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                weight = _transpose_hf_weight(key, weight)
                if return_numpy:
                    weight = weight.numpy()
            part_state_dict[key] = weight
        for key in keys:
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                scale = f.get_tensor(key)
                if not return_numpy and device == "expected":
                    scale = scale._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                if return_numpy:
                    scale = scale.numpy()
                scale_dict[key] = scale
    return part_state_dict, scale_dict


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
    """
    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}
    if Path(checkpoint_file).suffix == ".safetensors":
        require_deps("safetensors")
        from safetensors import safe_open

        with safe_open(checkpoint_file, framework="paddle") as f:
            state_dict, scale_dict = _load_part_state_dict_from_safetensors(
                list(f.keys()),
                checkpoint_file,
                tensor_parallel_split_mapping,
                fliter_dict_keys,
                "expected",
                quantization_linear_list=None,
                quantization_config=None,
                dtype=None,
                return_numpy=False,
                convert_from_hf=convert_from_hf,
                transpose_weight_keys=transpose_weight_keys,
            )
    else:
        state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
    return state_dict
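
# Usage sketch (illustrative; the paths below are hypothetical): safetensors
# files take the safe_open-based loader, everything else goes through
# `paddlenlp_load`.
#   state_dict = load_state_dict("ckpt/model.safetensors")     # safetensors route
#   state_dict = load_state_dict("ckpt/model_state.pdparams")  # paddlenlp_load route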


_re_layer_prefix = re.compile(r"\.(\d+)\.")


def _load_state_dict_into_model(
    model_to_load, state_dict, start_prefix, convert_from_hf
):
    # torch casts dtypes inside load_state_dict, but paddle checks dtypes strictly
    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)

    error_msgs = []

    if len(start_prefix) > 0:
        for key in list(state_dict.keys()):
            if key.startswith(start_prefix):
                state_dict[key.replace(start_prefix, "")] = state_dict.pop(key)

    # TODO: add return status to state_dict
    with warnings.catch_warnings(record=True) as w:
        warnings.resetwarnings()
        # paddlenlp tracks missing_keys itself, so "not found" warnings can be ignored.
        warnings.filterwarnings(
            "ignore", message=r".*is not found in the provided dict.*"
        )
        warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
        if convert_from_hf:
            try:
                model_to_load.set_hf_state_dict(state_dict)
            except NotImplementedError:
                pass
        model_to_load.set_state_dict(state_dict)
        error_msgs.extend([str(x.message) for x in w])

    del state_dict
    return error_msgs


def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
    # convert the dtype of the state dict
    def is_0d_or_1d(tensor):
        return len(tensor.shape) == 0 or list(tensor.shape) == [1]

    for key, value in model_to_load.state_dict().items():
        if key in list(state_dict.keys()):
            if isinstance(state_dict[key], np.ndarray):
                raise ValueError(
                    "convert_state_dict_dtype expected paddle.Tensor, not numpy.ndarray; "
                    "please convert the numpy.ndarray to paddle.Tensor"
                )
            # confirm the parameter cast is executed on the same device as the model
            # TODO: cast(FP32 -> FP16) differs across devices, need to fix it
            if (
                state_dict[key].is_floating_point()
                and state_dict[key].dtype != value.dtype
            ):
                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
            # unify 0-d and 1-d tensors
            if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                if list(value.shape) != list(state_dict[key].shape):
                    state_dict[key] = paddle.reshape(state_dict.pop(key), value.shape)


def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    dtype=None,
    is_safetensors=False,
    keep_in_fp32_modules=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
    params back to the normal device, but only for `loaded_state_dict_keys`.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`
    """
    from paddle.common_ops_import import convert_np_dtype_to_dtype_

    dtype = convert_np_dtype_to_dtype_(dtype)
    error_msgs = []
    model_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            continue
        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]
        if param.place != paddle.framework._current_expected_place():
            param = param._copy_to(paddle.framework._current_expected_place(), False)
        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and paddle.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and (dtype == paddle.float16 or dtype == paddle.bfloat16)
            ):
                param = param.astype(dtype=paddle.float32)
            else:
                param = param.astype(dtype=dtype)
        if dtype is None:
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break
            if old_param is not None:
                param = param.astype(dtype=old_param.dtype)
        with paddle.no_grad():
            model_state_dict[param_name].get_tensor()._share_data_with(
                param.value().get_tensor()
            )
            param.value().get_tensor()._clear()
    return error_msgs


class PretrainedModel(
    Layer, GenerationMixin, ConversionMixin, metaclass=InitTrackerMeta
):
    """
    The base class for all pretrained models. It mainly provides common methods
    for loading (construction and loading) and saving pretrained models. Loading
    and saving also rely on the following class attributes, which should be overridden
    by derived classes accordingly:

    - **model_config_file** (str): Represents the file name of the model configuration
      for configuration saving and loading in the local file system. The value is
      `model_config.json`.
    - **resource_files_names** (dict): Name of the local file where the model configuration
      can be saved and loaded locally. Currently, resources only include the model state,
      thus the dict only includes `'model_state'` as a key with the corresponding
      value `'model_state.pdparams'` for model weights saving and loading.
    - **pretrained_init_configuration** (dict): Provides the model configurations
      of built-in pretrained models (in contrast to models in the local file system).
      It has pretrained model names as keys (such as `bert-base-uncased`), and
      the values are dicts preserving the corresponding configuration for model initialization.
    - **pretrained_resource_files_map** (dict): Provides resource URLs of built-in
      pretrained models (in contrast to models in the local file system).
      It has the same keys as `resource_files_names` (that is, "model_state"),
      and the corresponding value is a dict mapping a specific model name to its model weights URL
      (such as "bert-base-uncased" ->
      "https://bj.bcebos.com/paddlenlp/models/transformers/bert-base-uncased.pdparams").
    - **base_model_prefix** (str): Represents the attribute associated with the
      base model in derived classes of the same architecture that add layers on
      top of the base model. Note: a base model class is a pretrained model class
      decorated by `register_base_model`, such as `BertModel`; a derived model
      class is a pretrained model class adding layers on top of the base model,
      and it has a base model as an attribute, such as `BertForSequenceClassification`.

    Methods common to models for text generation are defined in `GenerationMixin`
    and are also inherited here.

    Besides, metaclass `InitTrackerMeta` is used to create `PretrainedModel`,
    by which subclasses can track arguments for initialization automatically.
    """

    # Deprecated(wj-Mcat): after 2.6.* version
    # save the old-school `LEGACY_CONFIG_NAME`, will be changed to `CONFIG_NAME` after the 2.6.* version
    model_config_file = LEGACY_CONFIG_NAME
    pretrained_init_configuration = {}
    # TODO: more flexible resource handle, namedtuple with fields as:
    # resource_name, saved_file, handle_name_for_load(None for used as __init__
    # arguments), handle_name_for_save
    resource_files_names = {"model_state": PADDLE_WEIGHTS_NAME}
    pretrained_resource_files_map = {}
    base_model_prefix = ""
    main_input_name = "input_ids"
    config_class = None
    _keep_in_fp32_modules = None

    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
    # keys we find (keys inside the model but not in the checkpoint) to avoid unnecessary warnings.
    _keys_to_ignore_on_load_missing = None
    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
    # unexpected keys we find (keys inside the checkpoint but not the model) to avoid unnecessary
    # warnings.
    _keys_to_ignore_on_load_unexpected = None
    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
    # trained, but which are either deterministic or tied variables)
    _keys_to_ignore_on_save = None
    _tied_weights_keys = None

    def __init__(self, *args, **kwargs):
        super(PretrainedModel, self).__init__()
        if not self.constructed_from_pretrained_config():
            return

        # extract config from args
        config = None
        for arg in args:
            if isinstance(arg, PretrainedConfig):
                config = arg
                break
        if config is not None:
            self.config: PretrainedConfig = config
            self.model_config_file = CONFIG_NAME
            self.generation_config = (
                GenerationConfig.from_model_config(self.config)
                if self.can_generate()
                else None
            )
            return

        # extract config from kwargs
        if "config" not in kwargs:
            raise ValueError(
                "PretrainedConfig instance not found in the arguments; you can pass it positionally or as the `config` keyword argument"
            )

        config = kwargs["config"]
        if not isinstance(config, PretrainedConfig):
            raise TypeError(
                "the config parameter should be an instance of PretrainedConfig"
            )
        self.config: PretrainedConfig = kwargs["config"]
        self.generation_config = (
            GenerationConfig.from_model_config(self.config)
            if self.can_generate()
            else None
        )
        self.model_config_file = CONFIG_NAME
        self.warnings_issued = {}

    def _post_init(self, original_init, *args, **kwargs):
        """
        Hooked after `__init__` to add a dict containing the arguments of
        `__init__` as an attribute named `config` of the pretrained model instance.
        """
        if not self.constructed_from_pretrained_config():
            init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
            self.config = init_dict
        # only execute when it's the base method
        if (
            original_init.__module__ != "paddlenlp.transformers.model_utils"
            and self.__class__.init_weights is PretrainedModel.init_weights
        ):
            self.init_weights()

        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters.
        # However, `self._init_weights` will re-initialize the parameters without
        # synchronizing the shared parameters. If the following step does not load a checkpoint,
        # the shared parameters will be different.
        if isinstance(self, PipelineLayer):
            self._synchronize_shared_weights()

    def _init_weights(self, layer):
        """
        Initialize the weights. This method should be overridden by derived classes.
        """
        pass

    def _initialize_weights(self, layer):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(layer, "_is_initialized", False):
            return
        self._init_weights(layer)
        layer._is_initialized = True

    def init_weights(self):
        """
        If needed, prunes and maybe initializes weights. If using a custom `PretrainedModel`, you need to implement any
        initialization logic in `_init_weights`.
        """
        # call pure
        if _init_weights:
            # Initialize weights
            self.apply(self._initialize_weights)
            # Tie weights should be skipped when not initializing all weights
            # since from_pretrained(...) calls tie weights anyways
            # TODO(wj-Mcat): enable all tie-weights later
            # self.tie_weights()

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        dtype = kwargs.pop("dtype", None)
        if dtype is None:
            if config.dtype is not None:
                dtype = config.dtype
            else:
                dtype = paddle.get_default_dtype()

        with dtype_guard(dtype):
            model = cls(config, **kwargs)
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        return cls._from_config(config, **kwargs)
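
    # Usage sketch (illustrative; `MyModel` and the model name are hypothetical):
    # materialize a model from a config, overriding the dtype recorded in it.
    #   config = MyModel.config_class.from_pretrained("some/model")
    #   model = MyModel.from_config(config, dtype="float16")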

    @classmethod
    def set_inference_config(cls, config, predictor_args, **kwargs):
        """
        All inference config can be set here.
        Args:
            config : PretrainedConfig
                The config of the model.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
        tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)

        if predictor_args.mode == "dynamic" or predictor_args.speculate_method in [
            "eagle",
            "mtp",
        ]:
            config.tensor_parallel_degree = tensor_parallel_degree
            config.tensor_parallel_rank = tensor_parallel_rank
            config.model_name_or_path = predictor_args.model_name_or_path
            config.quant_type = predictor_args.quant_type
            config.cachekv_int8_type = predictor_args.cachekv_int8_type
            config.use_fake_parameter = predictor_args.use_fake_parameter
            config.single_card_ptq = not predictor_args.use_fake_parameter
            config.append_attn = predictor_args.append_attn
            config.decode_strategy = predictor_args.decode_strategy
            config.mla_use_matrix_absorption = predictor_args.mla_use_matrix_absorption
            config.weightonly_group_size = predictor_args.weightonly_group_size
            config.weight_block_size = predictor_args.weight_block_size
            config.moe_quant_type = predictor_args.moe_quant_type

        if config.quantization_config.quant_method is not None:
            predictor_args.weight_block_size = (
                config.quantization_config.weight_block_size
            )
            config.weight_block_size = predictor_args.weight_block_size

        if config.quantization_config.quant_type is not None:
            if predictor_args.mode == "dynamic":
                predictor_args.quant_type = config.quantization_config.quant_type
                config.quant_type = config.quantization_config.quant_type
            if "c8" in config.quant_type:
                predictor_args.cachekv_int8_type = "static"
                if predictor_args.mode == "dynamic":
                    config.cachekv_int8_type = "static"

        if predictor_args.mode == "dynamic":
            ptq_multicards_num = 0
            if os.path.exists(config.model_name_or_path):
                prefix = "act_scales_"
                for filename in os.listdir(config.model_name_or_path):
                    if filename.startswith(prefix):
                        ptq_multicards_num += 1
            logging.info(
                f"PTQ from {ptq_multicards_num} cards, so we will not split"
            )
            if ptq_multicards_num > 1:
                config.single_card_ptq = False

        if predictor_args.block_attn:
            config.block_size = predictor_args.block_size
            config.max_seq_len = predictor_args.total_max_length

        if predictor_args.speculate_method is not None:
            config.speculate_method = predictor_args.speculate_method
            config.speculate_max_draft_token_num = (
                predictor_args.speculate_max_draft_token_num
            )
            config.speculate_verify_window = predictor_args.speculate_verify_window
            config.speculate_max_candidate_len = (
                predictor_args.speculate_max_candidate_len
            )
            if predictor_args.speculate_method == "inference_with_reference":
                config.speculate_max_ngram_size = (
                    predictor_args.speculate_max_ngram_size
                )
        if predictor_args.speculate_method is not None:
            if config.get("speculate_model_type", "None") not in ["eagle", "mtp"]:
                config.decode_strategy = "speculate_decoding"

        config.return_full_hidden_states = predictor_args.return_full_hidden_states

    @classmethod
    def confirm_inference_model(cls, predictor_args, **kwargs):
        """
        Confirm whether the inference model needs to be switched to the AVX inference model.
        Args:
            model : PretrainedModel
                The model for inference.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        return cls

    @property
    def base_model(self):
        """
        PretrainedModel: The body of the same model architecture. It is the base
        model itself for a base model or the base model attribute for a derived
        model.
        """
        return getattr(self, self.base_model_prefix, self)

    @property
    def model_name_list(self):
        """
        list: Contains all supported built-in pretrained model names of the
        current PretrainedModel class.
        """
        # TODO: return all model names
        return list(self.pretrained_init_configuration.keys())

    def can_generate(self) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.
        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
            return False
        return True

    def recompute_enable(self):
        r"""
        Enable recompute.
        All layers with the `enable_recompute` attribute will be set to `True`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is False or layer.enable_recompute == 0
            ):
                layer.enable_recompute = True

        self.apply(fn)

    def recompute_disable(self):
        r"""
        Disable recompute.
        All layers with the `enable_recompute` attribute will be set to `False`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is True or layer.enable_recompute == 1
            ):
                layer.enable_recompute = False

        self.apply(fn)

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        """
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if output_embeddings is not None and input_embeddings is not None:
                if input_embeddings.weight.shape != output_embeddings.weight.shape:
                    logging.warning(
                        f"The shape of input embeddings is {input_embeddings.weight.shape} and the shape of output embeddings is {output_embeddings.weight.shape}. "
                        "This is only expected if you are calling the `resize_token_embeddings` method"
                    )
                output_embeddings.weight = input_embeddings.weight
                if getattr(output_embeddings, "bias", None) is not None:
                    # need to pad
                    if (
                        output_embeddings.weight.shape[0]
                        > output_embeddings.bias.shape[0]
                    ):
                        old_bias = output_embeddings.bias
                        pad_length = (
                            output_embeddings.weight.shape[0] - old_bias.shape[0]
                        )
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        new_bias = paddle.concat(
                            [
                                old_bias,
                                paddle.zeros(
                                    [pad_length], dtype=output_embeddings.bias.dtype
                                ),
                            ]
                        )
                        output_embeddings.bias.set_value(new_bias)
                    # need to trim
                    elif (
                        output_embeddings.weight.shape[0]
                        < output_embeddings.bias.shape[0]
                    ):
                        new_bias = output_embeddings.bias[
                            : output_embeddings.weight.shape[0]
                        ]
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        output_embeddings.bias.set_value(new_bias)

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """resize position embeddings; this method should be overridden by downstream models
        Args:
            new_num_position_embeddings (int): the new position embedding size
        Raises:
            NotImplementedError: when called and not implemented
        """
        raise NotImplementedError(
            f"`resize_position_embeddings` is not implemented for `{self.__class__}`. To implement it, you should "
            f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}.py`"
        )

    @classmethod
    def constructed_from_pretrained_config(cls, init_func=None) -> bool:
        """check if the model is constructed from `PretrainedConfig`
        Returns:
            bool: whether the model is constructed from `PretrainedConfig`
        """
        return cls.config_class is not None and issubclass(
            cls.config_class, PretrainedConfig
        )

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Resizes the input token embeddings matrix of the model according to new_num_tokens.
        Args:
            new_num_tokens (Optional[int]):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just
                returns a pointer to the input tokens embedding module of the model without doing anything.
        Returns:
            paddle.nn.Embedding: The input tokens Embeddings Module of the model.
        """
        old_embeddings: nn.Embedding = self.get_input_embeddings()
        if not new_num_tokens or new_num_tokens == old_embeddings.weight.shape[0]:
            return old_embeddings

        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)

        # Update vocab_size
        self.base_model.config["vocab_size"] = new_num_tokens
        self.vocab_size = new_num_tokens

        # update init_config
        self._update_init_config(self.init_config, "vocab_size", new_num_tokens)

        # Tie the weights between the input embeddings and the output embeddings if needed.
        self.tie_weights()

        return new_embeddings
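
    # Usage sketch (illustrative; `model` and `tokenizer` are hypothetical):
    # grow the vocabulary after adding special tokens; fresh rows are newly
    # initialized and `tie_weights()` is re-applied afterwards.
    #   model.resize_token_embeddings(new_num_tokens=len(tokenizer))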

    def _update_init_config(self, init_config: dict, key: str, value: Any):
        """update init_config with a <key, value> pair
        Args:
            init_config (dict): the init_config instance
            key (str): the key field
            value (Any): the new value for the field
        """
        if key in init_config:
            init_config[key] = value
            return
        for arg in init_config.get("init_args", []):
            if not isinstance(arg, PretrainedModel):
                continue
            self._update_init_config(arg.init_config, key, value)

    def _get_resized_embeddings(
        self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.
        Args:
            old_embeddings (nn.Embedding):
                Old embeddings to be resized.
            new_num_tokens (Optional[int]):
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end.
        Returns:
            paddle.nn.Embedding: The resized Embedding Module or the old Embedding Module if new_num_tokens is None.
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        if not isinstance(old_embeddings, nn.Embedding):
            raise TypeError(
                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
                " should either use a different resize function or make sure that old_embeddings are an instance of"
                f" {nn.Embedding}."
            )

        # Build new embeddings
        new_embeddings = nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            padding_idx=old_embeddings._padding_idx,
            sparse=old_embeddings._sparse,
        )

        # make sure the new embeddings' dtype matches the old embeddings' dtype
        if new_embeddings.weight.dtype != old_embeddings.weight.dtype:
            new_embeddings.to(dtype=old_embeddings.weight.dtype)

        # number of tokens to copy
        n = min(old_num_tokens, new_num_tokens)
        with paddle.no_grad():
            new_embeddings.weight[:n, :] = old_embeddings.weight[:n, :]
        return new_embeddings

    def __setattr__(self, name, value):
        value = adapt_stale_fwd_patch(self, name, value)
        return super(PretrainedModel, self).__setattr__(name, value)

    @classmethod
    def _resolve_model_file_path(
        cls: Type[PretrainedModel],
        pretrained_model_name_or_path: str,
        from_hf_hub: bool = False,
        from_aistudio: bool = False,
        cache_dir: str | None = None,
        subfolder: Optional[str] = "",
        config: PretrainedConfig = None,
        convert_from_torch: bool = False,
        use_safetensors: bool | None = None,
        variant=None,
    ) -> tuple:
        """resolve the model target file path from `pretrained_model_name_or_path` and `cache_dir`
        1. when it is a file path:
            return the weight file
        2. when it is a model name:
            2.1 check default `MODEL_HOME` + `model-name` + model_state.pdparams
            2.2 get the url from `pretrained_resource_files_map`, and set it to `pretrained_model_name_or_path`
        3. when it is a local dir:
            check whether the file <local_dir + weight_file> exists
        Args:
            cls (Type[PretrainedModel]): the inherited PretrainedModel class
            pretrained_model_name_or_path (str): the model name, url, or local dir
            cache_dir (Optional[str], optional): cache_dir is used when name_or_path is a model name or url. Defaults to None.
            convert_from_torch (bool, optional): whether to support converting a pytorch model to a paddle model
        Returns:
            tuple: (resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded)
        """
        is_sharded = False
        sharded_metadata = None

        if pretrained_model_name_or_path is not None:
            # the following code uses a lot of os.path.join, hence set subfolder to an empty str if None
            if subfolder is None:
                subfolder = ""
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            is_local = os.path.isdir(pretrained_model_name_or_path)

            def get_file_path(
                pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, variant
            ):
                return os.path.join(
                    pretrained_model_name_or_path,
                    subfolder,
                    _add_variant(SAFE_WEIGHTS_NAME, variant),
                )

            # pretrained_model_name_or_path is a file
            if os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
                is_local = True
            # pretrained_model_name_or_path is a dir
            elif is_local:
                if use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                else:
                    raise EnvironmentError(
                        f"Error: no file named {_add_variant(PADDLE_WEIGHTS_NAME, variant)} found in directory"
                        f" {pretrained_model_name_or_path}."
                    )
            elif pretrained_model_name_or_path in cls.pretrained_init_configuration:
                # fetch the weight url from the `pretrained_resource_files_map`
                resource_file_url = cls.pretrained_resource_files_map["model_state"][
                    pretrained_model_name_or_path
                ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    [resource_file_url],
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
            else:
                if use_safetensors is True:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                    ]
                elif use_safetensors is None:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                else:
                    filenames = [
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    filenames,
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError(
                        f"Error: no files {filenames} found in repo {pretrained_model_name_or_path}."
                    )
                elif "pytorch_model.bin" in str(resolved_archive_file):
                    if not from_hf_hub and not convert_from_torch:
                        raise ValueError(
                            f"Downloaded pytorch weights at"
                            f" {resolved_archive_file}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )

            if is_local:
                logging.info(f"Loading weights file {archive_file}")
                resolved_archive_file = archive_file
            else:
                logging.info(
                    f"Loading weights file from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        resolved_sharded_files = None
        if str(resolved_archive_file).endswith(".json"):
            is_sharded = True
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                from_aistudio=from_aistudio,
                from_hf_hub=from_hf_hub,
                cache_dir=cache_dir,
                subfolder=subfolder,
            )

        return (
            resolved_archive_file,
            resolved_sharded_files,
            sharded_metadata,
            is_sharded,
        )

    @classmethod
    def _load_pretrained_model(
        cls,
        model: PretrainedModel,
        state_dict: Dict[str, Tensor],
        loaded_keys: List[str],
        resolved_archive_file: Union[str, List] = [],
        pretrained_model_name_or_path=None,
        config=None,
        ignore_mismatched_sizes=False,
        low_cpu_mem_usage=False,
        dtype=None,
        keep_in_fp32_modules=None,
        quantization_linear_list=None,
        sharded_metadata=None,
        convert_from_hf=False,
    ) -> Tuple[List[str]]:
        """load the state_dict into the model, and do the following things:
            * check the
        Args:
            model (PretrainedModel): the pretrained model instance
            state_dict (Dict[str, Tensor]): the model state dict data
            loaded_keys (List[str]):
            ignore_mismatched_sizes (bool, optional): whether to ignore errors when tensor sizes mismatch. Defaults to False.
            dtype (_type_, optional): the dtype of the model state dict. Defaults to None.
        Returns:
            Tuple[List[str]]: _description_
        """
        is_safetensors = False

        if convert_from_hf:
            try:
                model_state_dict = model.get_hf_state_dict()
            except NotImplementedError:
                model_state_dict = model.state_dict()
        else:
            model_state_dict = model.state_dict()

        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [
                s for s in expected_keys if not s.startswith(_prefix)
            ]
            expected_keys = [
                s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys
            ]
            if quantization_linear_list is not None:
                quantization_linear_list = [
                    s[len(_prefix) :] if s.startswith(_prefix) else s
                    for s in quantization_linear_list
                ]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]
            if quantization_linear_list is not None:
                quantization_linear_list = [
                    ".".join([prefix, s]) for s in quantization_linear_list
                ]

        # Weight quantization if not yet quantized & update loaded_keys
        if (
            hasattr(config, "quantization_config")
            and config.quantization_config.is_weight_quantize()
        ):
            try:
                from ..quantization.quantization_utils import (
                    convert_to_quantize_state_dict,
                    update_loaded_state_dict_keys,
                )
            except ImportError:
                raise ImportError(
                    "Quantization features require `paddlepaddle >= 2.5.2`"
                )
            if state_dict is not None:
                state_dict = convert_to_quantize_state_dict(
                    state_dict,
                    quantization_linear_list,
                    config.quantization_config,
                    dtype,
                )
                loaded_keys = [k for k in state_dict.keys()]
            else:
                loaded_keys = update_loaded_state_dict_keys(
                    loaded_keys, quantization_linear_list, config.quantization_config
                )
            if keep_in_fp32_modules is None:
                keep_in_fp32_modules = (
                    ["quant_scale"]
                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
                    else None
                )
            else:
                keep_in_fp32_modules = (
                    keep_in_fp32_modules + ["quant_scale"]
                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
                    else keep_in_fp32_modules
                )
  1209. missing_keys = list(set(expected_keys) - set(loaded_keys))
  1210. unexpected_keys = list(set(loaded_keys) - set(expected_keys))
  1211. # Optimize for skip unused shard files for supper large model
  1212. if sharded_metadata is not None:
  1213. assert isinstance(resolved_archive_file, list)
  1214. new_archive_file = []
  1215. skip_archive_file = []
  1216. expected_keys_set = set(expected_keys)
  1217. for file in resolved_archive_file:
  1218. filename = os.path.split(file)[-1]
  1219. if not expected_keys_set.isdisjoint(
  1220. set(sharded_metadata["file_map"][filename])
  1221. ):
  1222. new_archive_file.append(file)
  1223. else:
  1224. skip_archive_file.append(filename)
  1225. resolved_archive_file = new_archive_file
  1226. if len(skip_archive_file) > 0:
  1227. logging.info(
  1228. f"Skip load files for not contrains expected key, {skip_archive_file}"
  1229. )
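
        # Shape of sharded_metadata["file_map"] this relies on (hypothetical
        # shard names, comment only):
        #   {"model-00001-of-00002.safetensors": ["embeddings.weight", ...],
        #    "model-00002-of-00002.safetensors": ["lm_head.weight", ...]}
        # A shard is dropped from the load list when none of its keys
        # intersect expected_keys.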
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [
                    k for k in unexpected_keys if re.search(pat, k) is None
                ]
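
        # For example, a subclass may declare (illustrative pattern):
        #   _keys_to_ignore_on_load_missing = [r"position_ids$"]
        # so deterministically initialized buffers are never reported as missing.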
        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.named_parameters():
                if any(
                    module_to_keep_in_fp32 in name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                ):
                    if param.dtype != paddle.float32:
                        param_fp32 = param.cast(dtype=paddle.float32)
                        param_fp32_tensor = param_fp32.value().get_tensor()
                        param_tensor = param.value().get_tensor()
                        param_tensor._share_data_with(param_fp32_tensor)
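                        # The cast-and-share above promotes the parameter's
                        # storage to fp32 in place, so existing references to
                        # the parameter keep pointing at the same tensor.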
        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if (
            len(cls.base_model_prefix) > 0
            and not hasattr(model, cls.base_model_prefix)
            and has_prefix_module
        ):
            start_prefix = cls.base_model_prefix + "."

        if (
            len(cls.base_model_prefix) > 0
            and hasattr(model, cls.base_model_prefix)
            and not has_prefix_module
        ):
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(
                key in expected_keys_not_prefixed
                and key not in base_model_expected_keys
                for key in loaded_keys
            ):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape
                        != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (
                                checkpoint_key,
                                state_dict[checkpoint_key].shape,
                                model_state_dict[model_key].shape,
                            )
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys
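
        # One returned entry looks like (illustrative shapes only):
        #   ("classifier.weight", [768, 2], [768, 3])
        # i.e. (checkpoint_key, checkpoint_shape, model_shape); the offending
        # tensor is deleted from state_dict so the model keeps its freshly
        # initialized weight instead.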
        def _fuse_or_split_keys(
            state_dict,
            config,
            loaded_keys,
            pre_tensor_parallel_split=False,
            resume_state_dict=None,
        ):
            if resume_state_dict is not None:
                state_dict.update(resume_state_dict)

            before_fuse_keys = list(state_dict.keys())
            if pre_tensor_parallel_split:
                tp_actions = cls.get_tensor_parallel_convert_actions(
                    config, loaded_keys, ignore_error=True
                )
            else:
                tp_actions = None
            state_dict, resume_state_dict = cls.convert_fuse_and_split(
                config, state_dict, tp_actions
            )
            after_fuse_keys = list(state_dict.keys())

            fused_keys = list(set(before_fuse_keys) - set(after_fuse_keys))
            new_keys = list(set(after_fuse_keys) - set(before_fuse_keys))
            return state_dict, resume_state_dict, fused_keys, new_keys
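
        # Illustrative sketch (hypothetical key names): fusing separate q/k/v
        # projections into a single tensor turns
        #   before_fuse_keys = ["q_proj.weight", "k_proj.weight", "v_proj.weight"]
        # into
        #   after_fuse_keys = ["qkv_proj.weight"]
        # so fused_keys are the consumed source keys and new_keys the produced ones.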
        if state_dict is not None:
            # have loaded all state_dict, no resume state_dict
            state_dict, _, fused_keys, new_keys = _fuse_or_split_keys(
                state_dict,
                config,
                loaded_keys,
                pre_tensor_parallel_split=(
                    config is not None and config.tensor_parallel_degree > 1
                ),
            )
            missing_keys = list(set(missing_keys) - set(new_keys))
            unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            if (
                hasattr(config, "quantization_config")
                and config.quantization_config.is_weight_quantize()
            ):
                error_msgs = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    dtype=dtype,
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                )
            else:
                error_msgs = _load_state_dict_into_model(
                    model_to_load,
                    state_dict,
                    start_prefix,
                    convert_from_hf=convert_from_hf,
                )
        else:
            # Sharded checkpoint, or whole checkpoint with low_cpu_mem_usage==True.
            # This should always be a list but, just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            resume_state_dict = {}

            for shard_file in resolved_archive_file:
                pre_tensor_parallel_split = False
                if (
                    shard_file.endswith(".safetensors")
                    and config.tensor_parallel_degree > 1
                    and "tp" not in os.path.split(shard_file)[-1]
                ):
                    pre_tensor_parallel_split = True
                    assert loaded_keys is not None, "loaded_keys should not be None."
                    tp_actions = cls.get_tensor_parallel_convert_actions(
                        config, loaded_keys, ignore_error=True
                    )
                # Here we use expected_keys to optimize weight loading for pipeline models. Only works for safetensors.
                filter_dict_keys = set(expected_keys)

                fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=True
                )
                split_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=False
                )
                for k in list(fuse_actions.keys()):
                    need_add_except_key = k[-1] in expected_keys
                    if need_add_except_key:
                        filter_dict_keys |= set(k[:-1])
                    # remove pre_tensor_parallel_split function from tp_actions
                    if pre_tensor_parallel_split:
                        for item in k[:-1]:
                            if item in tp_actions:
                                tp_actions.pop(item, None)

                for k in list(split_actions.keys()):
                    need_add_except_key = False
                    for item in k[:-1]:
                        if item in expected_keys:
                            need_add_except_key = True
                            break
                    if need_add_except_key:
                        filter_dict_keys.add(k[-1])
                    # remove pre_tensor_parallel_split function from tp_actions
                    if pre_tensor_parallel_split:
                        if k[-1] in tp_actions:
                            tp_actions.pop(k[-1], None)
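
                # fuse/split action keys are tuples: for fuse actions the last
                # element is the fused (destination) key and k[:-1] are the
                # source keys that must also be read from the shard; for split
                # actions the roles are reversed, so k[-1] is the stored key
                # needed to produce the split outputs in k[:-1]. Illustrative,
                # e.g. ("q_proj.weight", "k_proj.weight", "v_proj.weight",
                # "qkv_proj.weight") for a fuse action.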
                if config.quantization_config.is_weight_quantize():
                    filter_dict_keys = None

                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None

                state_dict = load_state_dict(
                    shard_file,
                    tp_actions if pre_tensor_parallel_split else None,
                    filter_dict_keys,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )

                # convert for fusing or splitting weights
                state_dict, resume_state_dict, fused_keys, new_keys = (
                    _fuse_or_split_keys(
                        state_dict,
                        config,
                        loaded_keys,
                        pre_tensor_parallel_split=pre_tensor_parallel_split,
                        resume_state_dict=resume_state_dict,
                    )
                )
                missing_keys = list(set(missing_keys) - set(new_keys))
                unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

                if config.quantization_config.is_weight_quantize():
                    state_dict = convert_to_quantize_state_dict(
                        state_dict,
                        quantization_linear_list,
                        config.quantization_config,
                        dtype,
                    )

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )

                if (
                    config.tensor_parallel_degree > 1
                    and ".tp" not in shard_file
                    and not pre_tensor_parallel_split
                ):
                    logging.info("Converting state_dict to Tensor Parallel Format")
                    # ignore errors when loading multiple shards, since each shard only holds part of the data
                    state_dict = cls.convert_tensor_parallel(
                        None,
                        config,
                        state_dict=state_dict,
                        ignore_error=len(resolved_archive_file) > 1,
                    )
                    logging.info("Converted state_dict to Tensor Parallel Format")

                if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
                    new_error_msgs = _load_state_dict_into_meta_model(
                        model_to_load,
                        state_dict,
                        loaded_keys,
                        start_prefix,
                        expected_keys,
                        dtype=dtype,
                        is_safetensors=is_safetensors,
                        keep_in_fp32_modules=keep_in_fp32_modules,
                    )
                    error_msgs += new_error_msgs
                else:
                    error_msgs += _load_state_dict_into_model(
                        model_to_load,
                        state_dict,
                        start_prefix,
                        convert_from_hf=convert_from_hf,
                    )

                # force memory release
                del state_dict
                gc.collect()
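
            # Shards are streamed one at a time: each shard's state_dict is
            # loaded, converted, copied into the model, then freed before the
            # next shard, keeping peak host memory near a single shard's size.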
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if " but the expected shape is" in error_msg:
                error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            raise RuntimeError(
                f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
            )

        if len(unexpected_keys) > 0:
            logging.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {sorted(unexpected_keys)}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logging.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logging.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )

        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys
    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, *args, convert_from_hf=False, **kwargs
    ):
        """
        Creates an instance of `PretrainedModel`. Model weights are loaded
        by specifying the name of a built-in pretrained model, a pretrained model from HF Hub, a community-contributed model,
        or a local file directory path.

        Args:
            pretrained_model_name_or_path (str): Name of pretrained model or dir path
                to load from. The string can be:

                - Name of a built-in pretrained model
                - Name of a pretrained model from HF Hub
                - Name of a community-contributed pretrained model.
                - Local directory path which contains model weights file ("model_state.pdparams")
                  and model config file ("model_config.json").
            from_hf_hub (bool): Load model from Hugging Face Hub. Default to `False`.
            subfolder (str, optional): An optional value corresponding to a folder inside the repo.
                Only works when loading from Hugging Face Hub.
            *args (tuple): Position arguments for model `__init__`. If provided,
                use these as position argument values for model initialization.
            **kwargs (dict): Keyword arguments for model `__init__`. If provided,
                use these to update pre-defined keyword argument values for model
                initialization. If the keyword is in `__init__` argument names of
                the base model, update argument values of the base model; else update
                argument values of the derived model.
            load_state_as_np (bool, optional): Whether the weights read in are placed
                on CPU rather than the default device. If `True`, load the model
                weights as `numpy.ndarray` on CPU. Otherwise, weights would be
                loaded as tensors on the default device. Note that if on GPU,
                the latter would create extra temporary tensors in addition to
                the model weights, which doubles the memory usage. Thus it is
                suggested to use `True` for big models on GPU. Default to `False`.

        Returns:
            PretrainedModel: An instance of `PretrainedModel`.

        Example:
            .. code-block::

                from paddlenlp.transformers import BertForSequenceClassification

                # Name of built-in pretrained model
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

                # Name of community-contributed pretrained model
                model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned', num_labels=3)

                # Load from local directory path
                model = BertForSequenceClassification.from_pretrained('./my_bert/')
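
                # Load with a specific dtype (illustrative; `dtype` is a
                # keyword consumed by this method)
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased', dtype='float16')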
        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        dtype = kwargs.pop("dtype", None)
        from_hf_hub = kwargs.pop("from_hf_hub", False)
        from_aistudio = kwargs.pop("from_aistudio", False)
        subfolder = kwargs.pop("subfolder", None)
        if subfolder is None:
            subfolder = ""
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop(
            "use_safetensors", None if is_dep_available("safetensors") else False
        )
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
        convert_from_torch = kwargs.pop("convert_from_torch", None)
        load_state_as_np = kwargs.pop("load_state_as_np", None)
        if load_state_as_np is not None:
            logging.warning("`load_state_as_np` is deprecated, please delete it!")

        model_kwargs = kwargs

        if convert_from_torch is None and os.environ.get("from_modelscope", False):
            logging.warning(
                "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
                " you can set `convert_from_torch=False`. By default, `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # from_hf_hub enables convert_from_torch by default
        if from_hf_hub and convert_from_torch is None:
            logging.warning(
                "If you are attempting to load weights from Hugging Face Hub and want to disable the default behavior of considering torch weights,"
                " you can set `convert_from_torch=False`. By default, `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # convert_from_torch defaults to False
        if convert_from_torch is None:
            convert_from_torch = False

        # 1. get the PretrainedConfig to init model
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config, model_kwargs = (
                cls.config_class.from_pretrained(  # NOTE cls.config_class: Qwen2VLForConditionalGeneration
                    config_path,
                    cache_dir=cache_dir,
                    from_hf_hub=from_hf_hub,
                    from_aistudio=from_aistudio,
                    subfolder=subfolder,
                    return_unused_kwargs=True,
                    **kwargs,
                )
            )
        if "from_aistudio" in model_kwargs:
            model_kwargs.pop("from_aistudio")

        if dtype is None:
            dtype = config.dtype
        config.dtype = dtype

        init_contexts = []
        if dtype:
            init_contexts.append(dtype_guard(dtype))

        # Keep in fp32 modules
        keep_in_fp32_modules = None
        use_keep_in_fp32_modules = False

        # resolve model_weight file
        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = (
            cls._resolve_model_file_path(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                subfolder=subfolder,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                config=config,
                convert_from_torch=False,
                use_safetensors=use_safetensors,
                variant=variant,
            )
        )

        init_args = config["init_args"] or ()
        with ContextManagers(init_contexts):
            model = cls(config, *init_args, **model_kwargs)

        # modified by zhch158
        transpose_weight_keys = []  # default initialization

        if convert_from_torch and state_dict is None:
            if (
                resolved_archive_file.endswith(PYTORCH_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PYTORCH_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
            ):
                # try to get the name-mapping info
                convert_dir = os.path.dirname(resolved_archive_file)
                logging.info(
                    f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
                    f"paddle weight file<{convert_dir}> ..."
                )
                state_dict = cls.convert(
                    resolved_archive_file,
                    config,
                    # cache_dir=os.path.join(cache_dir, pretrained_model_name_or_path, subfolder),
                    cache_dir=convert_dir,
                )
            elif (
                resolved_archive_file.endswith(PADDLE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PADDLE_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(".pdparams")
            ):
                logging.info(f"file: {resolved_archive_file} is a paddle weight file.")
            else:
                raise ValueError(
                    f"Unexpected file: {resolved_archive_file} for weight conversion."
                )

        # load pt weights early so that we know which dtype to init the model under
        if not is_sharded and state_dict is None:
            # 4. loading non-sharded ckpt from the state dict
            if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model_state.pdparams"
            ):
                state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
            elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model.safetensors"
            ):
                raise NotImplementedError
            else:
                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None
                state_dict = load_state_dict(
                    resolved_archive_file,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )

            logging.info("Loaded weights file from disk, setting weights to model.")

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            dtype == "float16" or dtype == "bfloat16"
        )

        if state_dict is not None:
            loaded_state_dict_keys = [k for k in state_dict.keys()]
            # will only support loading paddle.Tensor into the model.
            for k in list(state_dict.keys()):
                if not isinstance(state_dict[k], paddle.Tensor):
                    with device_guard():
                        state_dict[k] = paddle.Tensor.__call__(
                            state_dict.pop(k), zero_copy=True
                        )
        else:
            if is_sharded:
                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
            else:
                loaded_state_dict_keys = [k for k in state_dict.keys()]
            if low_cpu_mem_usage:  # or use_keep_in_fp32_modules:
                state_dict = None
            # will only support loading paddle.Tensor into the model.
            if state_dict is not None:
                for k in list(state_dict.keys()):
                    if not isinstance(state_dict[k], paddle.Tensor):
                        with device_guard():
                            state_dict[k] = paddle.Tensor.__call__(
                                state_dict.pop(k), zero_copy=True
                            )

        if use_keep_in_fp32_modules:
            # low_cpu_mem_usage = True
            keep_in_fp32_modules = model._keep_in_fp32_modules
        else:
            keep_in_fp32_modules = []

        quantization_linear_list = None
        model, missing_keys, unexpected_keys, mismatched_keys = (
            cls._load_pretrained_model(
                model=model,
                state_dict=state_dict,
                loaded_keys=loaded_state_dict_keys,
                resolved_archive_file=(
                    resolved_sharded_files if is_sharded else resolved_archive_file
                ),
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                config=config,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                low_cpu_mem_usage=low_cpu_mem_usage,
                dtype=dtype,
                keep_in_fp32_modules=keep_in_fp32_modules,
                quantization_linear_list=quantization_linear_list,
                sharded_metadata=sharded_metadata if is_sharded else None,
                convert_from_hf=convert_from_hf,
            )
        )

        # load generation_config.json
        if model.can_generate() and pretrained_model_name_or_path is not None:
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    from_hf_hub=from_hf_hub,
                    from_aistudio=from_aistudio,
                    subfolder=subfolder,
                    **kwargs,
                )
            except Exception:
                logging.info(
                    "Generation config file not found, using a generation config created from the model config."
                )

        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters.
        #    However, when the state dict only contains one piece of the shared parameters, the
        #    shared parameters will be different from the original shared parameters.
        if isinstance(model, PipelineLayer):
            model._synchronize_shared_weights()

        if paddle.in_dynamic_mode():
            return model

        return model, state_dict
    def merge_auto_dist_configs(self, configs):
        """
        Merge all auto dist configs into one config.
        `configs` is a list of configs; every config is a dict representing a model's auto_dist_config:
        [
            {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on global mesh.
                }
            },{
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on global mesh.
                }
            },....
        ]
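
        Example (illustrative): merging
        ``{"mp_config": {"parallelize_plan": {"layer_a": plan_a}}}`` with
        ``{"mp_config": {"parallelize_plan": {"layer_b": plan_b}}}`` yields
        ``{"mp_config": {"parallelize_plan": {"layer_a": plan_a, "layer_b": plan_b}}}``,
        while the ``split_spec`` lists of two ``pp_config`` entries are concatenated.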
        """
        assert isinstance(configs, (dict, list))
        if isinstance(configs, dict):
            return configs
        final_config = {
            "mp_config": None,
            "sp_config": None,
            "pp_config": None,
        }
        for config in configs:
            if "mp_config" in config and config["mp_config"] is not None:
                if final_config["mp_config"] is None:
                    final_config["mp_config"] = config["mp_config"]
                else:
                    for k, v in config["mp_config"]["parallelize_plan"].items():
                        assert (
                            k
                            not in final_config["mp_config"]["parallelize_plan"].keys()
                        ), f"sublayer mp_config should be a subset of model but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
                        final_config["mp_config"]["parallelize_plan"][k] = v
            if "sp_config" in config and config["sp_config"] is not None:
                if final_config["sp_config"] is None:
                    final_config["sp_config"] = config["sp_config"]
                else:
                    for k, v in config["sp_config"]["parallelize_plan"].items():
                        assert (
                            k
                            not in final_config["sp_config"]["parallelize_plan"].keys()
                        ), f"sublayer sp_config should be a subset of model but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
                        final_config["sp_config"]["parallelize_plan"][k] = v
            if "pp_config" in config and config["pp_config"] is not None:
                if isinstance(config["pp_config"]["split_spec"], str):
                    config["pp_config"]["split_spec"] = [
                        config["pp_config"]["split_spec"]
                    ]
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]
                elif isinstance(config["pp_config"]["split_spec"], (tuple, list)):
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]

        if (
            final_config["pp_config"] is not None
            and len(final_config["pp_config"]["split_spec"]) == 1
        ):
            final_config["pp_config"]["split_spec"] = final_config["pp_config"][
                "split_spec"
            ][0]

        return final_config
    def _generate_auto_dist_config(self, auto_dist_degree):
        merged_config = {
            "sp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        for name, layer in self.named_sublayers(include_self=True):
            if hasattr(layer, "auto_dist_config"):
                if name != "":
                    prefix = name + "."
                else:
                    prefix = ""
                layer_config = layer.auto_dist_config(prefix)
                merged_config = self.merge_auto_dist_configs(
                    [merged_config, layer_config]
                )
                for _, deeper_layer in layer.named_sublayers():
                    if hasattr(deeper_layer, "auto_dist_config"):
                        # mask all `auto_dist_config` methods in deeper layers
                        deeper_layer.auto_dist_config = lambda x: {}

        final_config = {
            "dp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        if (
            "tensor_parallel" in auto_dist_degree
            and auto_dist_degree["tensor_parallel"]
        ):
            assert merged_config["mp_config"] is not None
            final_config["mp_config"] = merged_config["mp_config"]
        if (
            "sequence_parallel" in auto_dist_degree
            and auto_dist_degree["sequence_parallel"]
        ):
            assert merged_config["sp_config"] is not None
            # the sequence-parallel plan is applied through the mp plan
            final_config["mp_config"] = merged_config["sp_config"]
        if (
            "pipeline_parallel" in auto_dist_degree
            and auto_dist_degree["pipeline_parallel"]
        ):
            assert merged_config["pp_config"] is not None
            final_config["pp_config"] = merged_config["pp_config"]

        return final_config
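
    # Illustrative usage sketch (comment only; `auto_dist_degree` keys as used
    # above):
    #   final = model._generate_auto_dist_config(
    #       {"tensor_parallel": True, "pipeline_parallel": False}
    #   )
    #   # -> final["mp_config"] comes from the merged sublayer plans,
    #   #    final["pp_config"] stays None.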
    def get_transpose_weight_keys(self):
        raise NotImplementedError

    def get_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError

    def set_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError