model_utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import gc
import os
import re
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import paddle
import paddle.nn as nn
from paddle import Tensor
from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer

try:
    from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc
except ImportError:
    LocalSharedLayerDesc = None

from paddle.nn import Layer

from ......utils import logging
from ......utils.deps import is_dep_available, require_deps
from ...tokenizer.tokenizer_utils import InitTrackerMeta, adapt_stale_fwd_patch
from ..generation import GenerationConfig, GenerationMixin
from ..utils import (
    ASYMMETRY_QUANT_SCALE_MAX,
    ASYMMETRY_QUANT_SCALE_MIN,
    CONFIG_NAME,
    LEGACY_CONFIG_NAME,
    PADDLE_WEIGHTS_INDEX_NAME,
    PADDLE_WEIGHTS_NAME,
    PYTORCH_WEIGHTS_INDEX_NAME,
    PYTORCH_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    SYMMETRY_QUANT_SCALE,
    device_guard,
    resolve_file_path,
)
from .configuration_utils import PretrainedConfig
from .conversion_utils import ConversionMixin
from .utils import (
    ContextManagers,
    fn_args_to_dict,
    get_checkpoint_shard_files,
    paddlenlp_load,
    weight_name_suffix,
)

__all__ = [
    "PretrainedModel",
]


def _add_variant(weights_name: str, variant=None) -> str:
    if variant is not None and len(variant) > 0:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)
    return weights_name
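

# Illustrative examples (not part of the original file): the variant is
# inserted just before the file extension, and a None/empty variant is a
# no-op.
#   _add_variant("model_state.pdparams", "fp16")  -> "model_state.fp16.pdparams"
#   _add_variant("model.safetensors", None)       -> "model.safetensors"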


@contextmanager
def dtype_guard(dtype="float32"):
    origin_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(origin_dtype)
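

# Illustrative usage sketch: parameters created inside the context pick up
# the requested default dtype; the previous default is restored afterwards,
# even if an exception is raised.
#   with dtype_guard("bfloat16"):
#       layer = paddle.nn.Linear(8, 8)  # weights created as bfloat16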


_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    old_init_weights = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = old_init_weights
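

# Illustrative usage sketch (`SomePretrainedModel` is a hypothetical
# subclass): weight initialization is skipped inside the context, so freshly
# constructed layers keep their uninitialized storage until a checkpoint
# overwrites it.
#   with no_init_weights():
#       model = SomePretrainedModel(config)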


def _split_keys_evenly(keys: list, n: int) -> list:
    """Split a list into n lists with an almost-equal number of elements.

    Args:
        keys (list): the list to be split
        n (int): number of splits

    Returns:
        result: list of lists
    """
    total_len = len(keys)
    base_size = total_len // n
    extra = total_len % n
    result = []
    index = 0
    for _ in range(n):
        part_size = base_size + 1 if extra > 0 else base_size
        extra -= 1
        result.append(keys[index : index + part_size])
        index += part_size
    return result
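

# Example (illustrative): the remainder is spread over the first splits, so
# seven keys split three ways yields sizes 3/2/2.
#   _split_keys_evenly(["a", "b", "c", "d", "e", "f", "g"], 3)
#   -> [["a", "b", "c"], ["d", "e"], ["f", "g"]]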


def _load_part_state_dict_from_safetensors(
    keys,
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping,
    fliter_dict_keys,
    device,
    quantization_linear_list=None,
    quantization_config=None,
    dtype=None,
    return_numpy=False,
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    import paddle
    from safetensors import safe_open

    if transpose_weight_keys:
        transpose_weight_keys = set(transpose_weight_keys)

    def _is_need_transpose(key):
        if "lora" not in key and convert_from_hf and transpose_weight_keys:
            return key in transpose_weight_keys

    def _transpose_hf_weight(key, weight):
        if _is_need_transpose(key):
            return weight.transpose([-1, -2])
        return weight

    part_state_dict = {}
    scale_dict = {}
    with safe_open(checkpoint_file, framework="paddle") as f:
        for key in keys:
            # 1. Non-merged checkpoint loading doesn't have a filter key.
            # 2. Merged checkpoints skip quant scales via `fliter_dict_keys`.
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                continue
            if fliter_dict_keys is not None and key not in fliter_dict_keys:
                continue
            py_safe_slice_ = f.get_slice(key)
            if (
                quantization_linear_list is not None
                and key.split(".weight")[0] in quantization_linear_list
                and not key.endswith("_scale")
            ):
                raise NotImplementedError
            else:
                if key in tensor_parallel_split_mapping:
                    tp_fn = tensor_parallel_split_mapping[key]
                    if _is_need_transpose(key):
                        assert isinstance(tp_fn, partial)
                        is_column = True
                        if "is_column" in tp_fn.keywords:
                            is_column = tp_fn.keywords["is_column"]
                        is_column = not is_column
                        tp_fn = partial(
                            tp_fn.func,
                            *tp_fn.args,
                            **{**tp_fn.keywords, "is_column": is_column},
                        )
                    if len(py_safe_slice_.shape) == 0:
                        weight = tp_fn(py_safe_slice_[:])
                    else:
                        weight = tp_fn(py_safe_slice_)
                else:
                    weight = py_safe_slice_[:]
                if not return_numpy and device == "expected":
                    weight = weight._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                weight = _transpose_hf_weight(key, weight)
                if return_numpy:
                    weight = weight.numpy()
            part_state_dict[key] = weight
        for key in keys:
            if (
                key.endswith(SYMMETRY_QUANT_SCALE)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
            ):
                scale = f.get_tensor(key)
                if not return_numpy and device == "expected":
                    scale = scale._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                if return_numpy:
                    scale = scale.numpy()
                scale_dict[key] = scale
    return part_state_dict, scale_dict


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
    convert_from_hf=False,
    transpose_weight_keys=None,
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
    """
    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}
    if Path(checkpoint_file).suffix == ".safetensors":
        require_deps("safetensors")
        from safetensors import safe_open

        with safe_open(checkpoint_file, framework="paddle") as f:
            state_dict, scale_dict = _load_part_state_dict_from_safetensors(
                list(f.keys()),
                checkpoint_file,
                tensor_parallel_split_mapping,
                fliter_dict_keys,
                "expected",
                dtype=None,
                return_numpy=False,
                convert_from_hf=convert_from_hf,
                transpose_weight_keys=transpose_weight_keys,
            )
    else:
        state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
    return state_dict
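

# Illustrative usage sketch (file names are hypothetical): a .safetensors
# shard is loaded key-by-key onto the current device, while any other suffix
# falls back to paddlenlp_load.
#   state_dict = load_state_dict("model-00001-of-00002.safetensors")
#   state_dict = load_state_dict("model_state.pdparams")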


_re_layer_prefix = re.compile(r"\.(\d+)\.")


def _load_state_dict_into_model(
    model_to_load, state_dict, start_prefix, convert_from_hf
):
    # torch casts dtypes in load_state_dict, but paddle checks dtypes strictly
    _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf)
    error_msgs = []
    if len(start_prefix) > 0:
        for key in list(state_dict.keys()):
            if key.startswith(start_prefix):
                state_dict[key.replace(start_prefix, "")] = state_dict.pop(key)
    # TODO: add return status to state_dict
    with warnings.catch_warnings(record=True) as w:
        warnings.resetwarnings()
        # paddlenlp tracks missing_keys itself, so ignore not-found warnings.
        warnings.filterwarnings(
            "ignore", message=r".*is not found in the provided dict.*"
        )
        warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
        if convert_from_hf:
            try:
                model_to_load.set_hf_state_dict(state_dict)
            except NotImplementedError:
                pass
        model_to_load.set_state_dict(state_dict)
        error_msgs.extend([str(x.message) for x in w])
    del state_dict
    return error_msgs


def _convert_state_dict_dtype_and_shape(state_dict, model_to_load, convert_from_hf):
    # convert the dtype of the state dict
    def is_0d_or_1d(tensor):
        return len(tensor.shape) == 0 or list(tensor.shape) == [1]

    if convert_from_hf:
        model_state_dict = model_to_load.get_hf_state_dict()
    else:
        model_state_dict = model_to_load.state_dict()
    for key, value in model_state_dict.items():
        if key in list(state_dict.keys()):
            if isinstance(state_dict[key], np.ndarray):
                raise ValueError(
                    "convert_state_dict_dtype expected a paddle.Tensor, not a numpy.ndarray; please convert the numpy.ndarray to a paddle.Tensor"
                )
            # confirm that the parameter cast runs on the same device as the model
            # TODO: cast(FP32 -> FP16) differs across devices; needs fixing
            if (
                state_dict[key].is_floating_point()
                and state_dict[key].dtype != value.dtype
            ):
                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
            # unify 0d and 1d tensors
            if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                if list(value.shape) != list(state_dict[key].shape):
                    state_dict[key] = paddle.reshape(state_dict.pop(key), value.shape)
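

# Illustrative sketch of the 0d/1d unification above: a checkpoint may store
# a scalar as shape [] while the model parameter has shape [1] (or vice
# versa); both hold a single element, so a reshape makes them compatible.
#   t = paddle.to_tensor(3.14)   # 0-D tensor, shape []
#   t = paddle.reshape(t, [1])   # shape [1], same single value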


def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    dtype=None,
    is_safetensors=False,
    keep_in_fp32_modules=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
    params back to the normal device, but only for `loaded_state_dict_keys`.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`.
    """
    from paddle.common_ops_import import convert_np_dtype_to_dtype_

    dtype = convert_np_dtype_to_dtype_(dtype)
    error_msgs = []
    model_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        # The first part of the test is always true, as loaded_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            continue
        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]
        if param.place != paddle.framework._current_expected_place():
            param = param._copy_to(paddle.framework._current_expected_place(), False)
        # We convert floating dtypes to the `dtype` passed, but keep the
        # int/uint/bool buffers/params uncast.
        if dtype is not None and paddle.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and (dtype == paddle.float16 or dtype == paddle.bfloat16)
            ):
                param = param.astype(dtype=paddle.float32)
            else:
                param = param.astype(dtype=dtype)
        if dtype is None:
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break
            if old_param is not None:
                param = param.astype(dtype=old_param.dtype)
        with paddle.no_grad():
            model_state_dict[param_name].get_tensor()._share_data_with(
                param.value().get_tensor()
            )
            param.value().get_tensor()._clear()
    return error_msgs


class PretrainedModel(
    Layer, GenerationMixin, ConversionMixin, metaclass=InitTrackerMeta
):
    """
    The base class for all pretrained models. It mainly provides common methods
    for loading (construction and loading) and saving pretrained models. Loading
    and saving also rely on the following class attributes, which should be
    overridden by derived classes accordingly:

    - **model_config_file** (str): Represents the file name of the model configuration
      for configuration saving and loading in the local file system. The value is
      `model_config.json`.
    - **resource_files_names** (dict): Names of local files where the model
      can be saved and loaded locally. Currently, resources only include the model state,
      thus the dict only includes `'model_state'` as a key, with the corresponding
      value `'model_state.pdparams'` for model weights saving and loading.
    - **pretrained_init_configuration** (dict): Provides the model configurations
      of built-in pretrained models (in contrast to models in the local file system).
      It has pretrained model names as keys (such as `bert-base-uncased`), and
      the values are dicts preserving the corresponding configuration for model initialization.
    - **pretrained_resource_files_map** (dict): Provides resource URLs of built-in
      pretrained models (in contrast to models in the local file system).
      It has the same keys as `resource_files_names` (that is, "model_state"),
      and the corresponding value is a dict mapping a specific model name to its
      model weights URL (such as "bert-base-uncased" ->
      "https://bj.bcebos.com/paddlenlp/models/transformers/bert-base-uncased.pdparams").
    - **base_model_prefix** (str): Represents the attribute associated with the
      base model in derived classes of the same architecture that add layers on
      top of the base model. Note: a base model class is a pretrained model class
      decorated by `register_base_model`, such as `BertModel`; a derived model
      class is a pretrained model class adding layers on top of the base model,
      and it has a base model as an attribute, such as `BertForSequenceClassification`.

    Methods common to models for text generation are defined in `GenerationMixin`
    and are also inherited here.

    Besides, metaclass `InitTrackerMeta` is used to create `PretrainedModel`,
    by which subclasses can track arguments for initialization automatically.
    """

    # Deprecated(wj-Mcat): after 2.6.* version
    # save the old-school `LEGACY_CONFIG_NAME`; will be changed to `CONFIG_NAME` after the 2.6.* version
    model_config_file = LEGACY_CONFIG_NAME
    pretrained_init_configuration = {}
    # TODO: more flexible resource handle, namedtuple with fields as:
    # resource_name, saved_file, handle_name_for_load(None for used as __init__
    # arguments), handle_name_for_save
    resource_files_names = {"model_state": PADDLE_WEIGHTS_NAME}
    pretrained_resource_files_map = {}
    base_model_prefix = ""
    main_input_name = "input_ids"
    config_class = None
    _keep_in_fp32_modules = None
    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
    # keys we find (keys inside the model but not in the checkpoint), to avoid unnecessary warnings.
    _keys_to_ignore_on_load_missing = None
    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
    # unexpected keys we find (keys inside the checkpoint but not the model), to avoid unnecessary
    # warnings.
    _keys_to_ignore_on_load_unexpected = None
    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
    # trained, but which are either deterministic or tied variables)
    _keys_to_ignore_on_save = None
    _tied_weights_keys = None

    def __init__(self, *args, **kwargs):
        super(PretrainedModel, self).__init__()
        if not self.constructed_from_pretrained_config():
            return
        # extract config from args
        config = None
        for arg in args:
            if isinstance(arg, PretrainedConfig):
                config = arg
                break
        if config is not None:
            self.config: PretrainedConfig = config
            self.model_config_file = CONFIG_NAME
            self.generation_config = (
                GenerationConfig.from_model_config(self.config)
                if self.can_generate()
                else None
            )
            return
        # extract config from kwargs
        if "config" not in kwargs:
            raise ValueError(
                "PretrainedConfig instance not found in the arguments; you can pass it as a positional argument or as the `config` keyword argument"
            )
        config = kwargs["config"]
        if not isinstance(config, PretrainedConfig):
            raise TypeError(
                "the config parameter should be an instance of PretrainedConfig"
            )
        self.config: PretrainedConfig = kwargs["config"]
        self.generation_config = (
            GenerationConfig.from_model_config(self.config)
            if self.can_generate()
            else None
        )
        self.model_config_file = CONFIG_NAME
        self.warnings_issued = {}

    def _post_init(self, original_init, *args, **kwargs):
        """
        Hooked after `__init__` to add a dict containing the arguments of
        `__init__` as an attribute named `config` on the pretrained model instance.
        """
        if not self.constructed_from_pretrained_config():
            init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
            self.config = init_dict
        # only execute when it's the base method
        if (
            original_init.__module__ != "paddlenlp.transformers.model_utils"
            and self.__class__.init_weights is PretrainedModel.init_weights
        ):
            self.init_weights()
        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters.
        # However, `self._init_weights` will re-initialize the parameters without
        # synchronizing the shared parameters. If the following step does not load a checkpoint,
        # the shared parameters will be different.
        if isinstance(self, PipelineLayer):
            self._synchronize_shared_weights()

    def _init_weights(self, layer):
        """
        Initialize the weights. This method should be overridden by derived classes.
        """
        pass

    def _initialize_weights(self, layer):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(layer, "_is_initialized", False):
            return
        self._init_weights(layer)
        layer._is_initialized = True

    def init_weights(self):
        """
        If needed, prunes and maybe initializes weights. If using a custom `PretrainedModel`, you need to implement
        any initialization logic in `_init_weights`.
        """
        # call pure
        if _init_weights:
            # Initialize weights
            self.apply(self._initialize_weights)
            # Tying weights is skipped when not initializing all weights,
            # since from_pretrained(...) calls tie_weights anyway.
            # TODO(wj-Mcat): enable all tie-weights later
            # self.tie_weights()

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        dtype = kwargs.pop("dtype", None)
        if dtype is None:
            if config.dtype is not None:
                dtype = config.dtype
            else:
                dtype = paddle.get_default_dtype()
        with dtype_guard(dtype):
            model = cls(config, **kwargs)
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

        Args:
            dtype (`paddle.dtype`, *optional*):
                Override the default `paddle.dtype` and load the model under this dtype.
        """
        return cls._from_config(config, **kwargs)
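
    # Illustrative usage sketch (`SomeConfig`/`SomeModel` are hypothetical
    # subclasses): constructing from a config alone builds randomly
    # initialized weights under the requested dtype, without reading any
    # checkpoint.
    #   config = SomeConfig(hidden_size=1024)
    #   model = SomeModel.from_config(config, dtype="bfloat16")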

    @classmethod
    def set_inference_config(cls, config, predictor_args, **kwargs):
        """
        All inference config can be set here.

        Args:
            config : PretrainedConfig
                The config of the model.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
        tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)
        if predictor_args.mode == "dynamic" or predictor_args.speculate_method in [
            "eagle",
            "mtp",
        ]:
            config.tensor_parallel_degree = tensor_parallel_degree
            config.tensor_parallel_rank = tensor_parallel_rank
            config.model_name_or_path = predictor_args.model_name_or_path
            config.quant_type = predictor_args.quant_type
            config.cachekv_int8_type = predictor_args.cachekv_int8_type
            config.use_fake_parameter = predictor_args.use_fake_parameter
            config.single_card_ptq = not predictor_args.use_fake_parameter
            config.append_attn = predictor_args.append_attn
            config.decode_strategy = predictor_args.decode_strategy
            config.mla_use_matrix_absorption = predictor_args.mla_use_matrix_absorption
            config.weightonly_group_size = predictor_args.weightonly_group_size
            config.weight_block_size = predictor_args.weight_block_size
            config.moe_quant_type = predictor_args.moe_quant_type
        if predictor_args.block_attn:
            config.block_size = predictor_args.block_size
            config.max_seq_len = predictor_args.total_max_length
        if predictor_args.speculate_method is not None:
            config.speculate_method = predictor_args.speculate_method
            config.speculate_max_draft_token_num = (
                predictor_args.speculate_max_draft_token_num
            )
            config.speculate_verify_window = predictor_args.speculate_verify_window
            config.speculate_max_candidate_len = (
                predictor_args.speculate_max_candidate_len
            )
            if predictor_args.speculate_method == "inference_with_reference":
                config.speculate_max_ngram_size = (
                    predictor_args.speculate_max_ngram_size
                )
        if predictor_args.speculate_method is not None:
            if config.get("speculate_model_type", "None") not in ["eagle", "mtp"]:
                config.decode_strategy = "speculate_decoding"
        config.return_full_hidden_states = predictor_args.return_full_hidden_states

    @classmethod
    def confirm_inference_model(cls, predictor_args, **kwargs):
        """
        Confirm whether the inference model needs to be switched to the AVX inference model.

        Args:
            model : PretrainedModel
                The model for inference.
            predictor_args : PredictorArgument
                The args of the predictor.
        """
        return cls

    @property
    def base_model(self):
        """
        PretrainedModel: The body of the same model architecture. It is the base
        model itself for a base model, or the base model attribute for a derived
        model.
        """
        return getattr(self, self.base_model_prefix, self)

    @property
    def model_name_list(self):
        """
        list: Contains all supported built-in pretrained model names of the
        current PretrainedModel class.
        """
        # TODO: return all model names
        return list(self.pretrained_init_configuration.keys())

    def can_generate(self) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detect whether `prepare_inputs_for_generation` has been overridden, which is a requirement for generation
        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
            return False
        return True

    def recompute_enable(self):
        r"""
        Enable recompute.

        All layers with the `enable_recompute` attribute will be set to `True`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is False or layer.enable_recompute == 0
            ):
                layer.enable_recompute = True

        self.apply(fn)

    def recompute_disable(self):
        r"""
        Disable recompute.

        All layers with the `enable_recompute` attribute will be set to `False`.
        """

        def fn(layer):
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is True or layer.enable_recompute == 1
            ):
                layer.enable_recompute = False

        self.apply(fn)
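
    # Illustrative usage sketch: toggle activation recomputation (gradient
    # checkpointing) across all sublayers that expose `enable_recompute`.
    #   model.recompute_enable()   # trade compute for memory during training
    #   model.recompute_disable()  # restore the normal forward behavior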

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        """
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if output_embeddings is not None and input_embeddings is not None:
                if input_embeddings.weight.shape != output_embeddings.weight.shape:
                    logging.warning(
                        f"The shape of input embeddings is {input_embeddings.weight.shape} and the shape of output embeddings is {output_embeddings.weight.shape}. "
                        "This is only expected if you are calling the `resize_token_embeddings` method"
                    )
                output_embeddings.weight = input_embeddings.weight
                if getattr(output_embeddings, "bias", None) is not None:
                    # need to pad
                    if (
                        output_embeddings.weight.shape[0]
                        > output_embeddings.bias.shape[0]
                    ):
                        old_bias = output_embeddings.bias
                        pad_length = (
                            output_embeddings.weight.shape[0] - old_bias.shape[0]
                        )
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        new_bias = paddle.concat(
                            [
                                old_bias,
                                paddle.zeros(
                                    [pad_length], dtype=output_embeddings.bias.dtype
                                ),
                            ]
                        )
                        output_embeddings.bias.set_value(new_bias)
                    # need to trim
                    elif (
                        output_embeddings.weight.shape[0]
                        < output_embeddings.bias.shape[0]
                    ):
                        new_bias = output_embeddings.bias[
                            : output_embeddings.weight.shape[0]
                        ]
                        output_embeddings.bias = output_embeddings.create_parameter(
                            shape=[output_embeddings.weight.shape[0]],
                            attr=output_embeddings._bias_attr,
                            dtype=output_embeddings._dtype,
                            is_bias=True,
                        )
                        output_embeddings.bias.set_value(new_bias)
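
    # Conceptual sketch (assumes `config.tie_word_embeddings` is True): after
    # tie_weights(), the output projection shares its weight tensor with the
    # input embedding, so a single parameter backs both modules:
    #   model.tie_weights()
    #   assert model.get_output_embeddings().weight is model.get_input_embeddings().weight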

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Resize position embeddings. This method should be overridden by downstream models.

        Args:
            new_num_position_embeddings (int): the new position embedding size

        Raises:
            NotImplementedError: when called without being implemented
        """
        raise NotImplementedError(
            f"`resize_position_embeddings` is not implemented for `{self.__class__}`. To implement it, you should "
            f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}.py`"
        )

    @classmethod
    def constructed_from_pretrained_config(cls, init_func=None) -> bool:
        """Check whether the model is constructed from a `PretrainedConfig`.

        Returns:
            bool: whether the model is constructed from `PretrainedConfig`
        """
        return cls.config_class is not None and issubclass(
            cls.config_class, PretrainedConfig
        )

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Resizes the input token embedding matrix of the model according to new_num_tokens.

        Args:
            new_num_tokens (Optional[int]):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or None, just
                returns a pointer to the input token embedding module of the model without doing anything.

        Returns:
            paddle.nn.Embedding: The input token embedding module of the model.
        """
        old_embeddings: nn.Embedding = self.get_input_embeddings()
        if not new_num_tokens or new_num_tokens == old_embeddings.weight.shape[0]:
            return old_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        # 2. Update vocab_size
        self.base_model.config["vocab_size"] = new_num_tokens
        self.vocab_size = new_num_tokens
        # update init_config
        self._update_init_config(self.init_config, "vocab_size", new_num_tokens)
        # Tie the weights between the input embeddings and the output embeddings if needed.
        self.tie_weights()
        return new_embeddings
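
    # Illustrative usage sketch (a hypothetical tokenizer/model pair): grow
    # the vocabulary after adding special tokens; the old rows are kept and
    # newly initialized rows are appended at the end.
    #   tokenizer.add_special_tokens({"additional_special_tokens": ["<tool>"]})
    #   model.resize_token_embeddings(len(tokenizer))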

    def _update_init_config(self, init_config: dict, key: str, value: Any):
        """Update init_config with a <key, value> pair.

        Args:
            init_config (dict): the init_config instance
            key (str): the key field
            value (Any): the new value for the key
        """
        if key in init_config:
            init_config[key] = value
            return
        for arg in init_config.get("init_args", []):
            if not isinstance(arg, PretrainedModel):
                continue
            self._update_init_config(arg.init_config, key, value)

    def _get_resized_embeddings(
        self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
    ) -> nn.Embedding:
        """
        Build a resized Embedding module from a provided token Embedding module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (nn.Embedding):
                Old embeddings to be resized.
            new_num_tokens (Optional[int]):
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end.

        Returns:
            paddle.nn.Embedding: The resized Embedding module, or the old Embedding module if new_num_tokens is None.
        """
        if new_num_tokens is None:
            return old_embeddings
        old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
        if old_num_tokens == new_num_tokens:
            return old_embeddings
        if not isinstance(old_embeddings, nn.Embedding):
            raise TypeError(
                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
                " should either use a different resize function or make sure that old_embeddings are an instance of"
                f" {nn.Embedding}."
            )
        # Build new embeddings
        new_embeddings = nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            padding_idx=old_embeddings._padding_idx,
            sparse=old_embeddings._sparse,
        )
        # make sure that the new embeddings' dtype matches the old embeddings' dtype
        if new_embeddings.weight.dtype != old_embeddings.weight.dtype:
            new_embeddings.to(dtype=old_embeddings.weight.dtype)
        # number of tokens to copy
        n = min(old_num_tokens, new_num_tokens)
        with paddle.no_grad():
            new_embeddings.weight[:n, :] = old_embeddings.weight[:n, :]
        return new_embeddings

    def __setattr__(self, name, value):
        value = adapt_stale_fwd_patch(self, name, value)
        return super(PretrainedModel, self).__setattr__(name, value)

    @classmethod
    def _resolve_model_file_path(
        cls: Type[PretrainedModel],
        pretrained_model_name_or_path: str,
        from_hf_hub: bool = False,
        from_aistudio: bool = False,
        cache_dir: str | None = None,
        subfolder: Optional[str] = "",
        config: PretrainedConfig = None,
        convert_from_torch: bool = False,
        use_safetensors: bool | None = None,
        variant=None,
    ):
        """Resolve the model target file path from `pretrained_model_name_or_path` and `cache_dir`.

        1. when it is a file path:
            return the weight file
        2. when it is a model name:
            2.1 check the default `MODEL_HOME` + `model-name` + model_state.pdparams
            2.2 get the url from `pretrained_resource_files_map`, and set it to `pretrained_model_name_or_path`
        3. when it is a local dir:
            check whether the file <local_dir + weight_file> exists

        Args:
            cls (Type[PretrainedModel]): the inherited PretrainedModel class
            pretrained_model_name_or_path (str): the model name, url, or local dir
            cache_dir (Optional[str], optional): cache_dir is used when name_or_path is a model name or url. Defaults to None.
            convert_from_torch (bool, optional): whether to support converting a pytorch model to a paddle model

        Returns:
            A `(resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded)` tuple.
        """
        is_sharded = False
        sharded_metadata = None
        if pretrained_model_name_or_path is not None:
            # the following code uses os.path.join a lot, hence set subfolder to an empty str if None
            if subfolder is None:
                subfolder = ""
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            is_local = os.path.isdir(pretrained_model_name_or_path)

            def get_file_path(
                pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, variant
            ):
                return os.path.join(
                    pretrained_model_name_or_path,
                    subfolder,
                    _add_variant(SAFE_WEIGHTS_NAME, variant),
                )

            # pretrained_model_name_or_path is a file
            if os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
                is_local = True
            # pretrained_model_name_or_path is a dir
            elif is_local:
                if use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        variant,
                    )
                elif use_safetensors is not False and os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a safetensors checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        SAFE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        variant,
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a sharded PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_INDEX_NAME,
                        weight_name_suffix(),
                    )
                    is_sharded = True
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                ):
                    # Load from a PaddlePaddle checkpoint
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        variant,
                    )
                elif os.path.isfile(
                    get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                ):
                    # Load from a PaddlePaddle checkpoint for a hybrid parallel model
                    archive_file = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        PADDLE_WEIGHTS_NAME,
                        weight_name_suffix(),
                    )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                elif os.path.isfile(
                    os.path.join(
                        pretrained_model_name_or_path,
                        subfolder,
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    )
                ):
                    if from_hf_hub or convert_from_torch:
                        archive_file = os.path.join(
                            pretrained_model_name_or_path,
                            subfolder,
                            _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                        )
                    else:
                        raise ValueError(
                            f"Found {_add_variant(PYTORCH_WEIGHTS_NAME, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
                else:
                    raise EnvironmentError(
                        f"Error: no file named {_add_variant(PADDLE_WEIGHTS_NAME, variant)} found in directory"
                        f" {pretrained_model_name_or_path}."
                    )
            elif pretrained_model_name_or_path in cls.pretrained_init_configuration:
                # fetch the weight url from the `pretrained_resource_files_map`
                resource_file_url = cls.pretrained_resource_files_map["model_state"][
                    pretrained_model_name_or_path
                ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    [resource_file_url],
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
            else:
                if use_safetensors is True:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                    ]
                elif use_safetensors is None:
                    filenames = [
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(SAFE_WEIGHTS_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                else:
                    filenames = [
                        _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PADDLE_WEIGHTS_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                        _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                    ]
                resolved_archive_file = resolve_file_path(
                    pretrained_model_name_or_path,
                    filenames,
                    subfolder,
                    cache_dir=cache_dir,
                    from_aistudio=from_aistudio,
                    from_hf_hub=from_hf_hub,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError(
                        f"Error: none of the files {filenames} found in repo {pretrained_model_name_or_path}."
                    )
                elif "pytorch_model.bin" in str(resolved_archive_file):
                    if not from_hf_hub and not convert_from_torch:
                        raise ValueError(
                            f"Downloaded a pytorch weight file"
                            f" {resolved_archive_file}. Please set convert_from_torch=True in from_pretrained, e.g. Model.from_pretrained(model_name, convert_from_torch=True)"
                        )
            if is_local:
                logging.info(f"Loading weights file {archive_file}")
                resolved_archive_file = archive_file
            else:
                logging.info(
                    f"Loading weights file from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        resolved_sharded_files = None
        if str(resolved_archive_file).endswith(".json"):
            is_sharded = True
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                from_aistudio=from_aistudio,
                from_hf_hub=from_hf_hub,
                cache_dir=cache_dir,
                subfolder=subfolder,
            )

        return (
            resolved_archive_file,
            resolved_sharded_files,
            sharded_metadata,
            is_sharded,
        )
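
    # Illustrative usage sketch (the repo name is hypothetical; the return
    # order follows the code above): resolve a local directory or hub name
    # to concrete weight files before loading.
    #   archive, shards, shard_meta, sharded = cls._resolve_model_file_path(
    #       "org-name/some-model", use_safetensors=True
    #   )
    #   # when `sharded` is True, `archive` is the index .json and `shards`
    #   # lists the individual shard paths.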
  1081. @classmethod
  1082. def _load_pretrained_model(
  1083. cls,
  1084. model: PretrainedModel,
  1085. state_dict: Dict[str, Tensor],
  1086. loaded_keys: List[str],
  1087. resolved_archive_file: Union[str, List] = [],
  1088. pretrained_model_name_or_path=None,
  1089. config=None,
  1090. ignore_mismatched_sizes=False,
  1091. low_cpu_mem_usage=False,
  1092. dtype=None,
  1093. keep_in_fp32_modules=None,
  1094. quantization_linear_list=None,
  1095. sharded_metadata=None,
  1096. convert_from_hf=False,
  1097. ) -> Tuple[List[str]]:
  1098. """load the state_dict into model, and do the following things:
  1099. * check the
  1100. Args:
  1101. model (PretrainedModel): the pretrained model instance
  1102. state_dict (Dict[str, Tensor]): the model state dict data
  1103. loaded_keys (List[str]):
  1104. ignore_mismatched_sizes (bool, optional): whether ignore error when tensor size mismatched. Defaults to False.
  1105. dtype (_type_, optional): the dtype of model state dict. Defaults to None.
  1106. Returns:
  1107. Tuple[List[str]]: _description_
  1108. """
  1109. is_safetensors = False
  1110. if convert_from_hf:
  1111. try:
  1112. model_state_dict = model.get_hf_state_dict()
  1113. except NotImplementedError:
  1114. model_state_dict = model.state_dict()
  1115. else:
  1116. model_state_dict = model.state_dict()
  1117. expected_keys = list(model_state_dict.keys())
  1118. prefix = model.base_model_prefix
  1119. if len(prefix) > 0:
  1120. has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
  1121. expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
  1122. else:
  1123. has_prefix_module = False
  1124. expects_prefix_module = False
  1125. # key re-naming operations are never done on the keys
  1126. # that are loaded, but always on the keys of the newly initialized model
  1127. remove_prefix_from_model = not has_prefix_module and expects_prefix_module
  1128. add_prefix_to_model = has_prefix_module and not expects_prefix_module
  1129. if remove_prefix_from_model:
  1130. _prefix = f"{prefix}."
  1131. expected_keys_not_prefixed = [
  1132. s for s in expected_keys if not s.startswith(_prefix)
  1133. ]
  1134. expected_keys = [
  1135. s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys
  1136. ]
  1137. if quantization_linear_list is not None:
  1138. quantization_linear_list = [
  1139. s[len(_prefix) :] if s.startswith(_prefix) else s
  1140. for s in quantization_linear_list
  1141. ]
  1142. elif add_prefix_to_model:
  1143. expected_keys = [".".join([prefix, s]) for s in expected_keys]
  1144. if quantization_linear_list is not None:
  1145. quantization_linear_list = [
  1146. ".".join([prefix, s]) for s in quantization_linear_list
  1147. ]
  1148. missing_keys = list(set(expected_keys) - set(loaded_keys))
  1149. unexpected_keys = list(set(loaded_keys) - set(expected_keys))
  1150. # Optimize for skip unused shard files for supper large model
  1151. if sharded_metadata is not None:
  1152. assert isinstance(resolved_archive_file, list)
  1153. new_archive_file = []
  1154. skip_archive_file = []
  1155. expected_keys_set = set(expected_keys)
  1156. for file in resolved_archive_file:
  1157. filename = os.path.split(file)[-1]
  1158. if not expected_keys_set.isdisjoint(
  1159. set(sharded_metadata["file_map"][filename])
  1160. ):
  1161. new_archive_file.append(file)
  1162. else:
  1163. skip_archive_file.append(filename)
  1164. resolved_archive_file = new_archive_file
  1165. if len(skip_archive_file) > 0:
  1166. logging.info(
  1167. f"Skip load files for not contrains expected key, {skip_archive_file}"
  1168. )
  1169. # Some models may have keys that are not in the state by design, removing them before needlessly warning
  1170. # the user.
  1171. if cls._keys_to_ignore_on_load_missing is not None:
  1172. for pat in cls._keys_to_ignore_on_load_missing:
  1173. missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
  1174. if cls._keys_to_ignore_on_load_unexpected is not None:
  1175. for pat in cls._keys_to_ignore_on_load_unexpected:
  1176. unexpected_keys = [
  1177. k for k in unexpected_keys if re.search(pat, k) is None
  1178. ]
  1179. # Set some modules to fp32 if any
  1180. if keep_in_fp32_modules is not None:
  1181. for name, param in model.named_parameters():
  1182. if any(
  1183. module_to_keep_in_fp32 in name
  1184. for module_to_keep_in_fp32 in keep_in_fp32_modules
  1185. ):
  1186. if param.dtype != paddle.float32:
  1187. param_fp32 = param.cast(dtype=paddle.float32)
  1188. param_fp32_tensor = param_fp32.value().get_tensor()
  1189. param_tensor = param.value().get_tensor()
  1190. param_tensor._share_data_with(param_fp32_tensor)
  1191. # Make sure we are able to load base models as well as derived models (with heads)
  1192. start_prefix = ""
  1193. model_to_load = model
  1194. if (
  1195. len(cls.base_model_prefix) > 0
  1196. and not hasattr(model, cls.base_model_prefix)
  1197. and has_prefix_module
  1198. ):
  1199. start_prefix = cls.base_model_prefix + "."
  1200. if (
  1201. len(cls.base_model_prefix) > 0
  1202. and hasattr(model, cls.base_model_prefix)
  1203. and not has_prefix_module
  1204. ):
  1205. model_to_load = getattr(model, cls.base_model_prefix)
  1206. base_model_expected_keys = list(model_to_load.state_dict().keys())
  1207. if any(
  1208. key in expected_keys_not_prefixed
  1209. and key not in base_model_expected_keys
  1210. for key in loaded_keys
  1211. ):
  1212. raise ValueError(
  1213. "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
  1214. "properly saved?"
  1215. )
        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape
                        != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (
                                checkpoint_key,
                                state_dict[checkpoint_key].shape,
                                model_state_dict[model_key].shape,
                            )
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys
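        # For illustration (hypothetical key names): with prefix "bert" and
        # remove_prefix_from_model=True, checkpoint key "encoder.weight" is compared
        # against model key "bert.encoder.weight"; with add_prefix_to_model=True,
        # "bert.encoder.weight" is compared against "encoder.weight". Only shape
        # mismatches are recorded, and each offending entry is dropped from
        # state_dict so it is never loaded.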
        def _fuse_or_split_keys(
            state_dict,
            config,
            loaded_keys,
            pre_tensor_parallel_split=False,
            resume_state_dict=None,
        ):
            if resume_state_dict is not None:
                state_dict.update(resume_state_dict)

            before_fuse_keys = list(state_dict.keys())
            if pre_tensor_parallel_split:
                tp_actions = cls.get_tensor_parallel_convert_actions(
                    config, loaded_keys, ignore_error=True
                )
            else:
                tp_actions = None
            state_dict, resume_state_dict = cls.convert_fuse_and_split(
                config, state_dict, tp_actions
            )
            after_fuse_keys = list(state_dict.keys())

            fused_keys = list(set(before_fuse_keys) - set(after_fuse_keys))
            new_keys = list(set(after_fuse_keys) - set(before_fuse_keys))
            return state_dict, resume_state_dict, fused_keys, new_keys
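        # For illustration (hypothetical key names): if convert_fuse_and_split fuses
        # "q_proj.weight" and "k_proj.weight" into "qkv_proj.weight", then
        # fused_keys == ["q_proj.weight", "k_proj.weight"] (consumed from the
        # checkpoint) and new_keys == ["qkv_proj.weight"] (produced for the model);
        # the caller uses these lists to reconcile missing/unexpected keys.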
        if state_dict is not None:
            # The whole state dict has been loaded, so there is no resume state dict.
            state_dict, _, fused_keys, new_keys = _fuse_or_split_keys(
                state_dict,
                config,
                loaded_keys,
                pre_tensor_parallel_split=(
                    config is not None and config.tensor_parallel_degree > 1
                ),
            )
            missing_keys = list(set(missing_keys) - set(new_keys))
            unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(
                model_to_load,
                state_dict,
                start_prefix,
                convert_from_hf=convert_from_hf,
            )
        else:
            # Sharded checkpoint, or a whole checkpoint with low_cpu_mem_usage==True.
            # This should always be a list, but check just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            resume_state_dict = {}

            for shard_file in resolved_archive_file:
                pre_tensor_parallel_split = False
                if (
                    shard_file.endswith(".safetensors")
                    and config.tensor_parallel_degree > 1
                    and "tp" not in os.path.split(shard_file)[-1]
                ):
                    pre_tensor_parallel_split = True
                    assert loaded_keys is not None, "loaded_keys should not be None."
                    tp_actions = cls.get_tensor_parallel_convert_actions(
                        config, loaded_keys, ignore_error=True
                    )
                # Here we use expected_keys to optimize weights loading for pipeline models.
                # Only works for safetensors.
                filter_dict_keys = set(expected_keys)

                fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=True
                )
                split_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                    config, loaded_keys, is_fuse=False
                )
                for k in list(fuse_actions.keys()):
                    need_add_except_key = k[-1] in expected_keys
                    if need_add_except_key:
                        filter_dict_keys |= set(k[:-1])
                    # Remove the pre_tensor_parallel_split function from tp_actions.
                    if pre_tensor_parallel_split:
                        for item in k[:-1]:
                            if item in tp_actions:
                                tp_actions.pop(item, None)

                for k in list(split_actions.keys()):
                    need_add_except_key = False
                    for item in k[:-1]:
                        if item in expected_keys:
                            need_add_except_key = True
                            break
                    if need_add_except_key:
                        filter_dict_keys.add(k[-1])
                    # Remove the pre_tensor_parallel_split function from tp_actions.
                    if pre_tensor_parallel_split:
                        if k[-1] in tp_actions:
                            tp_actions.pop(k[-1], None)
                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None
                state_dict = load_state_dict(
                    shard_file,
                    tp_actions if pre_tensor_parallel_split else None,
                    filter_dict_keys,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )

                # Convert for fusing or splitting weights.
                state_dict, resume_state_dict, fused_keys, new_keys = (
                    _fuse_or_split_keys(
                        state_dict,
                        config,
                        loaded_keys,
                        pre_tensor_parallel_split=pre_tensor_parallel_split,
                        resume_state_dict=resume_state_dict,
                    )
                )
                missing_keys = list(set(missing_keys) - set(new_keys))
                unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

                # mismatched_keys contains tuples (key, shape1, shape2) of weights in the checkpoint
                # whose shape does not match the corresponding weight in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )

                if (
                    config.tensor_parallel_degree > 1
                    and ".tp" not in shard_file
                    and not pre_tensor_parallel_split
                ):
                    logging.info("Converting state_dict to Tensor Parallel Format")
                    # Ignore errors for multiple shards, since each shard holds only part of the data.
                    state_dict = cls.convert_tensor_parallel(
                        None,
                        config,
                        state_dict=state_dict,
                        ignore_error=len(resolved_archive_file) > 1,
                    )
                    logging.info("Converted state_dict to Tensor Parallel Format")

                if low_cpu_mem_usage:
                    new_error_msgs = _load_state_dict_into_meta_model(
                        model_to_load,
                        state_dict,
                        loaded_keys,
                        start_prefix,
                        expected_keys,
                        dtype=dtype,
                        is_safetensors=is_safetensors,
                        keep_in_fp32_modules=keep_in_fp32_modules,
                    )
                    error_msgs += new_error_msgs
                else:
                    error_msgs += _load_state_dict_into_model(
                        model_to_load,
                        state_dict,
                        start_prefix,
                        convert_from_hf=convert_from_hf,
                    )

                # Force memory release.
                del state_dict
                gc.collect()
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if " but the expected shape is" in error_msg:
                error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            raise RuntimeError(
                f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
            )

        if len(unexpected_keys) > 0:
            logging.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {sorted(unexpected_keys)}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logging.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logging.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logging.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path, *args, convert_from_hf=False, **kwargs
    ):
        """
        Creates an instance of `PretrainedModel`. Model weights are loaded by
        specifying the name of a built-in pretrained model, a pretrained model
        on HF Hub, a community-contributed model, or a local file directory path.

        Args:
            pretrained_model_name_or_path (str): Name of the pretrained model or
                directory path to load from. The string can be:

                - Name of a built-in pretrained model
                - Name of a pretrained model from HF Hub
                - Name of a community-contributed pretrained model
                - Local directory path which contains the model weights file
                  ("model_state.pdparams") and model config file ("model_config.json")
            from_hf_hub (bool): Load the model from Hugging Face Hub. Defaults to `False`.
            subfolder (str, optional): An optional value corresponding to a folder inside the repo.
                Only works when loading from Hugging Face Hub.
            *args (tuple): Positional arguments for the model `__init__`. If provided,
                use these as positional argument values for model initialization.
            **kwargs (dict): Keyword arguments for the model `__init__`. If provided,
                use these to update pre-defined keyword argument values for model
                initialization. If the keyword is in the `__init__` argument names of the
                base model, update argument values of the base model; else update
                argument values of the derived model.
            load_state_as_np (bool, optional): Deprecated. The loaded weights can be
                placed on CPU rather than on the default device. If `True`, load the
                model weights as `numpy.ndarray` on CPU. Otherwise, weights are loaded
                as tensors on the default device. Note that on GPU the latter creates
                extra temporary tensors in addition to the model weights, which doubles
                the memory usage. Thus it is suggested to use `True` for big models on
                GPU. Defaults to `False`.

        Returns:
            PretrainedModel: An instance of `PretrainedModel`.

        Example:
            .. code-block::

                from paddlenlp.transformers import BertForSequenceClassification

                # Name of a built-in pretrained model
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

                # Name of a pretrained model from HF Hub
                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

                # Name of a community-contributed pretrained model
                model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned', num_labels=3)

                # Load from a local directory path
                model = BertForSequenceClassification.from_pretrained('./my_bert/')
        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        dtype = kwargs.pop("dtype", None)
        from_hf_hub = kwargs.pop("from_hf_hub", False)
        from_aistudio = kwargs.pop("from_aistudio", False)
        subfolder = kwargs.pop("subfolder", None)
        if subfolder is None:
            subfolder = ""
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop(
            "use_safetensors", None if is_dep_available("safetensors") else False
        )
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
        convert_from_torch = kwargs.pop("convert_from_torch", None)
        load_state_as_np = kwargs.pop("load_state_as_np", None)
        if load_state_as_np is not None:
            logging.warning("`load_state_as_np` is deprecated, please delete it!")

        model_kwargs = kwargs

        if convert_from_torch is None and os.environ.get("from_modelscope", False):
            logging.warning(
                "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
                " you can set `convert_from_torch=False`. By default, `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # from_hf_hub enables convert_from_torch by default
        if from_hf_hub and convert_from_torch is None:
            logging.warning(
                "If you are attempting to load weights from Hugging Face Hub and want to disable the default behavior of considering torch weights,"
                " you can set `convert_from_torch=False`. By default, `convert_from_torch` is set to `True`."
            )
            convert_from_torch = True
        # convert_from_torch defaults to False
        if convert_from_torch is None:
            convert_from_torch = False
        # 1. get the PretrainedConfig to init the model
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            # NOTE cls.config_class : Qwen2VLForConditionalGeneration
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                subfolder=subfolder,
                return_unused_kwargs=True,
                **kwargs,
            )
        if "from_aistudio" in model_kwargs:
            model_kwargs.pop("from_aistudio")

        if dtype is None:
            dtype = config.dtype
        config.dtype = dtype

        init_contexts = []
        if dtype:
            init_contexts.append(dtype_guard(dtype))

        # Keep-in-fp32 modules
        keep_in_fp32_modules = None
        use_keep_in_fp32_modules = False

        # resolve the model weight file(s)
        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = (
            cls._resolve_model_file_path(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                subfolder=subfolder,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                config=config,
                convert_from_torch=False,
                use_safetensors=use_safetensors,
                variant=variant,
            )
        )

        init_args = config["init_args"] or ()
        with ContextManagers(init_contexts):
            model = cls(config, *init_args, **model_kwargs)
        # modified by zhch158
        transpose_weight_keys = []  # default initialization
        if convert_from_torch and state_dict is None:
            if (
                resolved_archive_file.endswith(PYTORCH_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PYTORCH_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
            ):
                # try to get the name-mapping info
                convert_dir = os.path.dirname(resolved_archive_file)
                logging.info(
                    f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
                    f"paddle weight file<{convert_dir}> ..."
                )
                state_dict = cls.convert(
                    resolved_archive_file,
                    config,
                    # cache_dir=os.path.join(cache_dir, pretrained_model_name_or_path, subfolder),
                    cache_dir=convert_dir,
                )
            elif (
                resolved_archive_file.endswith(PADDLE_WEIGHTS_NAME)
                or resolved_archive_file.endswith(PADDLE_WEIGHTS_INDEX_NAME)
                or resolved_archive_file.endswith(".pdparams")
            ):
                logging.info(f"file: {resolved_archive_file} is already a paddle weight file.")
            else:
                raise ValueError(
                    f"Unexpected file: {resolved_archive_file} for weight conversion."
                )
        # load pt weights early so that we know which dtype to init the model under
        if not is_sharded and state_dict is None:
            # 4. loading a non-sharded checkpoint into the state dict
            if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model_state.pdparams"
            ):
                state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
            elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
                "model.safetensors"
            ):
                raise NotImplementedError
            else:
                try:
                    transpose_weight_keys = model.get_transpose_weight_keys()
                except NotImplementedError:
                    if convert_from_hf:
                        raise ValueError("`convert_from_hf=True` is not supported")
                    else:
                        transpose_weight_keys = None
                state_dict = load_state_dict(
                    resolved_archive_file,
                    convert_from_hf=convert_from_hf,
                    transpose_weight_keys=transpose_weight_keys,
                )

            logging.info("Loaded weights file from disk, setting weights to model.")
        # Check whether `_keep_in_fp32_modules` applies for this dtype.
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            dtype == "float16" or dtype == "bfloat16"
        )

        if state_dict is not None:
            loaded_state_dict_keys = [k for k in state_dict.keys()]
            # Only paddle.Tensor values are supported when loading into the model.
            for k in list(state_dict.keys()):
                if not isinstance(state_dict[k], paddle.Tensor):
                    with device_guard():
                        state_dict[k] = paddle.Tensor.__call__(
                            state_dict.pop(k), zero_copy=True
                        )
        else:
            if is_sharded:
                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
            else:
                loaded_state_dict_keys = [k for k in state_dict.keys()]
            if low_cpu_mem_usage:  # or use_keep_in_fp32_modules:
                state_dict = None

        # Only paddle.Tensor values are supported when loading into the model.
        if state_dict is not None:
            for k in list(state_dict.keys()):
                if not isinstance(state_dict[k], paddle.Tensor):
                    with device_guard():
                        state_dict[k] = paddle.Tensor.__call__(
                            state_dict.pop(k), zero_copy=True
                        )

        if use_keep_in_fp32_modules:
            # low_cpu_mem_usage = True
            keep_in_fp32_modules = model._keep_in_fp32_modules
        else:
            keep_in_fp32_modules = []

        quantization_linear_list = None
        model, missing_keys, unexpected_keys, mismatched_keys = (
            cls._load_pretrained_model(
                model=model,
                state_dict=state_dict,
                loaded_keys=loaded_state_dict_keys,
                resolved_archive_file=(
                    resolved_sharded_files if is_sharded else resolved_archive_file
                ),
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                config=config,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                low_cpu_mem_usage=low_cpu_mem_usage,
                dtype=dtype,
                keep_in_fp32_modules=keep_in_fp32_modules,
                quantization_linear_list=quantization_linear_list,
                sharded_metadata=sharded_metadata if is_sharded else None,
                convert_from_hf=convert_from_hf,
            )
        )

        # load generation_config.json
        if model.can_generate() and pretrained_model_name_or_path is not None:
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    from_hf_hub=from_hf_hub,
                    from_aistudio=from_aistudio,
                    subfolder=subfolder,
                    **kwargs,
                )
            except Exception:
                logging.info(
                    "Generation config file not found, using a generation config created from the model config."
                )

        # Note:
        # 1. PipelineLayer will create parameters for each layer and
        #    call `_synchronize_shared_weights()` to synchronize the shared parameters.
        # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
        #    synchronize the shared parameters. However, when the state dict only contains one piece
        #    of the shared parameters, the shared parameters will differ from the original ones.
        if isinstance(model, PipelineLayer):
            model._synchronize_shared_weights()

        if paddle.in_dynamic_mode():
            return model

        return model, state_dict
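
    # A hedged usage sketch (the model name is hypothetical; the kwargs are the
    # ones handled above): loading a sharded bfloat16 checkpoint while keeping
    # host-memory overhead low.
    #
    #     model = SomeModelClass.from_pretrained(
    #         "org/model-name",
    #         dtype="bfloat16",
    #         low_cpu_mem_usage=True,
    #         use_safetensors=True,
    #     )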

    def merge_auto_dist_configs(self, configs):
        """
        Merge all auto dist configs into one config.
        `configs` is a list of configs; each config is a dict representing one model's auto_dist_config:
        [
            {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on the global mesh.
                }
            }, {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), the pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on the global mesh.
                }
            }, ...
        ]
        """
        assert isinstance(configs, (dict, list))
        if isinstance(configs, dict):
            return configs
        final_config = {
            "mp_config": None,
            "sp_config": None,
            "pp_config": None,
        }
        for config in configs:
            if "mp_config" in config and config["mp_config"] is not None:
                if final_config["mp_config"] is None:
                    final_config["mp_config"] = config["mp_config"]
                else:
                    for k, v in config["mp_config"]["parallelize_plan"].items():
                        assert (
                            k not in final_config["mp_config"]["parallelize_plan"].keys()
                        ), f"sublayer mp_config should be a subset of model but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
                        final_config["mp_config"]["parallelize_plan"][k] = v
            if "sp_config" in config and config["sp_config"] is not None:
                if final_config["sp_config"] is None:
                    final_config["sp_config"] = config["sp_config"]
                else:
                    for k, v in config["sp_config"]["parallelize_plan"].items():
                        assert (
                            k not in final_config["sp_config"]["parallelize_plan"].keys()
                        ), f"sublayer sp_config should be a subset of model but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
                        final_config["sp_config"]["parallelize_plan"][k] = v
            if "pp_config" in config and config["pp_config"] is not None:
                if isinstance(config["pp_config"]["split_spec"], str):
                    config["pp_config"]["split_spec"] = [
                        config["pp_config"]["split_spec"]
                    ]
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]
                elif isinstance(config["pp_config"]["split_spec"], (tuple, list)):
                    if final_config["pp_config"] is None:
                        final_config["pp_config"] = config["pp_config"]
                    else:
                        final_config["pp_config"]["split_spec"] += config["pp_config"][
                            "split_spec"
                        ]

        if (
            final_config["pp_config"] is not None
            and len(final_config["pp_config"]["split_spec"]) == 1
        ):
            final_config["pp_config"]["split_spec"] = final_config["pp_config"][
                "split_spec"
            ][0]

        return final_config

    def _generate_auto_dist_config(self, auto_dist_degree):
        merged_config = {
            "sp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        for name, layer in self.named_sublayers(include_self=True):
            if hasattr(layer, "auto_dist_config"):
                if name != "":
                    prefix = name + "."
                else:
                    prefix = ""
                layer_config = layer.auto_dist_config(prefix)
                merged_config = self.merge_auto_dist_configs(
                    [merged_config, layer_config]
                )
                for _, deeper_layer in layer.named_sublayers():
                    if hasattr(deeper_layer, "auto_dist_config"):
                        # mask all `auto_dist_config` methods in deeper layers
                        deeper_layer.auto_dist_config = lambda x: {}

        final_config = {
            "dp_config": None,
            "mp_config": None,
            "pp_config": None,
        }
        if (
            "tensor_parallel" in auto_dist_degree
            and auto_dist_degree["tensor_parallel"]
        ):
            assert merged_config["mp_config"] is not None
            final_config["mp_config"] = merged_config["mp_config"]

        if (
            "sequence_parallel" in auto_dist_degree
            and auto_dist_degree["sequence_parallel"]
        ):
            assert merged_config["sp_config"] is not None
            final_config["mp_config"] = merged_config["sp_config"]

        if (
            "pipeline_parallel" in auto_dist_degree
            and auto_dist_degree["pipeline_parallel"]
        ):
            assert merged_config["pp_config"] is not None
            final_config["pp_config"] = merged_config["pp_config"]

        return final_config
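
    # An illustrative input for _generate_auto_dist_config (the keys are the ones
    # checked above; the values are hypothetical):
    #
    #     auto_dist_degree = {
    #         "tensor_parallel": True,
    #         "sequence_parallel": False,
    #         "pipeline_parallel": True,
    #     }
    #     dist_config = model._generate_auto_dist_config(auto_dist_degree)
    #     # dist_config["mp_config"] / dist_config["pp_config"] then feed the
    #     # auto-parallel setup; when sequence_parallel is enabled, sp_config
    #     # replaces mp_config in the result.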

    def get_transpose_weight_keys(self):
        raise NotImplementedError

    def get_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError

    def set_hf_state_dict(self, *args, **kwargs):
        raise NotImplementedError