conversion_utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, TypeVar

import numpy as np
import paddle
from numpy import ndarray, transpose

from ......utils import logging
from ..distributed import distributed_allgather, distributed_gather
from ..utils import device_guard, get_env_device

if TYPE_CHECKING:
    from .configuration_utils import PretrainedConfig

# type hints for the pytorch model, layer & tensor
Module = TypeVar("Module")
PytorchTensor = TypeVar("PytorchTensor")
@dataclass
class StateDictNameMapping:
    """NameMapping of StateDict between two models"""

    source_name: str
    target_name: Optional[str] = None
    action: Optional[str] = None  # can be: transpose, merge_last_two_dim, split, or a callable
    index: Optional[int] = None
    slots: Optional[list[str]] = None

    def __post_init__(self):
        self.target_name = self.target_name or self.source_name

    def should_transpose(self) -> bool:
        return self.action == "transpose"

    def should_merge_last_two_dim(self) -> bool:
        """check whether the action merges the last two dims"""
        return self.action == "merge_last_two_dim"
    def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
        """run some custom operation on an ndarray, eg: transpose, merge_last_two_dim

        Args:
            state_dict (dict[str, ndarray]): the state dict that holds the source tensor
            name (str): the key of the source tensor

        Returns:
            ndarray: the final tensor
        """
        tensor = state_dict.pop(name)
        if callable(self.action):
            return self.action(tensor)
        if self.action == "transpose":
            return transpose(tensor, [1, 0])
        if self.action == "merge_last_two_dim":
            shape = tensor.shape
            assert len(shape) == 3
            return np.reshape(tensor, [shape[0], -1])
        if self.action == "split":
            assert self.index is not None, "when action is `split`, the index field is required."
            # FIXME: if the splits are processed starting from index=2, no tensor is left for the others.
            if self.index < 2:
                state_dict[name] = tensor
            # q/k/v are stored in the same tensor, so split it into 3 arrays
            tensors = np.split(tensor, 3, axis=-1)
            return tensors[self.index]
        return tensor
    def matched(self, text: str) -> bool:
        """check whether the layer name matches the current pattern

        Args:
            text (str): the name of the layer

        Returns:
            bool: whether the layer name matches the pattern
        """
        if text == self.source_name:
            return True
        if not self.slots:
            return False
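
# Illustrative sketch of applying a mapping action (names and shapes below are
# made up for this example, not taken from any real model):
#
#     import numpy as np
#     mapping = StateDictNameMapping("linear.weight", action="transpose")
#     sd = {"linear.weight": np.zeros([2, 3])}
#     mapping.run(sd, "linear.weight").shape  # -> (3, 2)
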
class ConversionMixin:
    @classmethod
    def support_conversion(cls, config: PretrainedConfig) -> bool:
        """check whether the model supports name-mapping conversion"""
        try:
            # try to get the name-mapping info
            _ = cls._get_name_mappings(config)
        except NotImplementedError:
            return False
        return True
    @classmethod
    def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMapping]:
        """get name mapping of PretrainedModel

        Args:
            config (PretrainedConfig): the configuration of name-mapping

        Raises:
            NotImplementedError:

        Returns:
            List[StateDictNameMapping]: the name-mappings of pretrained model
        """
        raise NotImplementedError
    @classmethod
    def get_tensor_parallel_convert_actions(
        cls,
        config: PretrainedConfig,
        loaded_state_dict_keys,
        is_split=True,
        ignore_error=False,
    ):
        name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
        state_keys_map = cls._resolve_prefix_keys(
            name_action_mappings.keys(), loaded_state_dict_keys, ignore_error
        )
        for k, v in state_keys_map.items():
            name_action_mappings[v] = name_action_mappings.pop(k)
        return name_action_mappings
    @classmethod
    def convert_tensor_parallel(
        cls,
        weight_file: str,
        config: PretrainedConfig,
        state_dict=None,
        ignore_error=False,
    ):
        """convert a full state dict into a tensor parallel state dict for the current rank

        Args:
            weight_file (str | None): the path of the `model_state.pdparams` weight file,
                used when `state_dict` is not given
            config (PretrainedConfig): the PretrainedConfig instance of the model

        Returns:
            dict: the converted tensor parallel state dict
        """
        name_action_mappings = cls._get_tensor_parallel_mappings(config)
        if state_dict is None:
            with device_guard("cpu"):
                state_dict = paddle.load(weight_file, return_numpy=False)
            logging.info("Starting to convert original state_dict to tensor parallel state_dict.")

        state_keys_map = cls._resolve_prefix_keys(
            name_action_mappings.keys(), state_dict.keys(), ignore_error
        )
        for k, v in state_keys_map.items():
            name_action_mappings[v] = name_action_mappings.pop(k)

        for name, action in name_action_mappings.items():
            if name not in state_dict:
                if not ignore_error:
                    logging.warning(f"Key <{name}> not in the model state weight file.")
                continue
            tensor = state_dict.pop(name)
            new_tensor = action(tensor)
            with device_guard("cpu"):
                state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
        return state_dict
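
    # Illustrative usage sketch (`MyModel` is a hypothetical subclass that
    # implements `_get_tensor_parallel_mappings`; the path is a placeholder):
    #
    #     sharded = MyModel.convert_tensor_parallel("model_state.pdparams", config)
    #     # `sharded` now holds this rank's slice of every tensor-parallel weight
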
    @classmethod
    def merge_tensor_parallel(cls, state_dict, config):
        """merge the tensor parallel state dict shards back into a full state dict

        Args:
            state_dict (dict): the tensor parallel state dict of the current rank
            config (PretrainedConfig): the PretrainedConfig instance of the model

        Returns:
            dict: the merged state dict (full tensors on the destination rank,
                None values on the other ranks)
        """
        name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=False)
        state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys())
        for k, v in state_keys_map.items():
            name_action_mappings[v] = name_action_mappings.pop(k)

        state_dict_to_save = {}
        hcg = paddle.distributed.fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        is_dst = paddle.distributed.get_rank(mp_group) == 0
        for key in state_dict.keys():
            tensor = state_dict[key]
            if key in name_action_mappings:
                if get_env_device() == "xpu":
                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
                else:
                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                action = name_action_mappings.pop(key)
                tensor = action(ret) if is_dst else None
            else:
                tensor = tensor.cpu().numpy() if is_dst else None
            # keep the state dict holding paddle.Tensor, not numpy arrays
            if isinstance(tensor, np.ndarray):
                with device_guard("cpu"):
                    tensor = paddle.Tensor(tensor, zero_copy=True)
            state_dict_to_save[key] = tensor

        if len(name_action_mappings) > 0:
            for x in name_action_mappings.keys():
                logging.debug(
                    f"key <{x}> needs tensor parallel merging but was not found in the model state."
                )
        return state_dict_to_save
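
    # Illustrative usage sketch (assumes a fleet hybrid-parallel run and the
    # same hypothetical `MyModel` subclass as above):
    #
    #     merged = MyModel.merge_tensor_parallel(model.state_dict(), config)
    #     # on mp-rank 0 `merged` holds the full tensors; on the other ranks
    #     # the dict values are None
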
    @classmethod
    def _get_tensor_parallel_mappings(
        cls, config: PretrainedConfig, is_split=True
    ) -> List[StateDictNameMapping]:
        """get the tensor parallel name mappings of PretrainedModel

        Args:
            config (PretrainedConfig): the configuration of name-mapping

        Raises:
            NotImplementedError:

        Returns:
            List[StateDictNameMapping]: the name-mappings for tensor_parallel
        """
        raise NotImplementedError
    @staticmethod
    def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
        # maps base keys to real keys
        state_keys_map = {}
        # sort by length and match from longest to shortest, so that a longer
        # suffix such as `A.key` wins over a plain `key`
        state_keys_base = sorted(state_keys_base, key=lambda x: len(x), reverse=True)
        state_keys_real = set(state_keys_real)
        for key in state_keys_base:
            for x in state_keys_real:
                if x.endswith(key):
                    state_keys_map[key] = x
                    break
            if key not in state_keys_map:
                if not ignore_error:
                    logging.debug(
                        f"tensor parallel conversion: could not find name {key} in loaded state dict!"
                    )
            else:
                state_keys_real.remove(state_keys_map[key])
        return state_keys_map
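
    # Example of prefix resolution (keys are made up for illustration):
    #
    #     ConversionMixin._resolve_prefix_keys(["linear.weight"], ["model.linear.weight"])
    #     # -> {"linear.weight": "model.linear.weight"}
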
    @classmethod
    def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
        loaded_keys = state_dict.keys()
        # collect and apply the fuse/split actions
        fused_and_split_keys = []
        convert_with_same_keys = []
        fuse_actions, resume_keys = cls.get_fuse_or_split_param_convert_actions(
            config, loaded_keys, is_fuse=True
        )
        for keys, action in fuse_actions.items():
            if keys[-1] in keys[:-1]:
                assert len(keys) == 2, "only 2 keys can be converted with the same name"
                convert_with_same_keys.append(keys[-1])
            origin_states = [state_dict.pop(key) for key in keys[:-1]]
            state_dict[keys[-1]] = action(origin_states)
            fused_and_split_keys.append(keys[-1])
            logging.debug(f"Fusing parameter: {keys[:-1]} into {keys[-1]}")

        split_actions, _ = cls.get_fuse_or_split_param_convert_actions(
            config, loaded_keys, is_fuse=False
        )
        for keys, action in split_actions.items():
            if keys[-1] in keys[:-1]:
                assert len(keys) == 2, "only 2 keys can be converted with the same name"
                convert_with_same_keys.append(keys[-1])
            origin_state = state_dict.pop(keys[-1])
            split_states = action(origin_state)
            for key_idx, key in enumerate(keys[:-1]):
                state_dict[key] = split_states[key_idx]
                fused_and_split_keys.append(key)
            logging.debug(f"Splitting parameter: {keys[-1]} into {keys[:-1]}")

        if tp_actions is not None:
            for key in fused_and_split_keys:
                if key in convert_with_same_keys:
                    continue
                for name in tp_actions.keys():
                    if key.endswith(name):
                        with device_guard():
                            state_dict[key] = paddle.Tensor(
                                tp_actions[name](state_dict.pop(key)), zero_copy=True
                            )
                        break

        resume_state_dict = {k: state_dict[k] for k in resume_keys if k in state_dict}
        return state_dict, resume_state_dict
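
    # Illustrative sketch (hypothetical subclass whose
    # `_get_fuse_or_split_param_mappings` fuses separate q/k/v weights into one
    # qkv weight):
    #
    #     state_dict, resume = MyModel.convert_fuse_and_split(config, state_dict)
    #     # the q/k/v entries are replaced by the fused key; `resume` holds keys
    #     # whose fuse partners were not in this shard and are kept for the next one
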
    @classmethod
    def get_fuse_or_split_param_convert_actions(
        cls,
        config: PretrainedConfig,
        loaded_state_dict_keys,
        is_fuse=True,
        ignore_error=False,
    ):
        name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
        state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
            name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse
        )
        for k, v in state_keys_map.items():
            name_action_mappings[v] = name_action_mappings.pop(k)

        # keep only the actions whose source keys are actually present
        filter_name_action = {}
        resume_keys = []
        if is_fuse:
            for k, v in name_action_mappings.items():
                cond = True
                if not all(item in loaded_state_dict_keys for item in k[:-1]):
                    # keep the partial keys around so the fuse can resume on the next shard
                    resume_keys += k[:-1]
                    cond = False
                if cond:
                    filter_name_action[k] = v
        else:
            for k, v in name_action_mappings.items():
                if k[-1] in loaded_state_dict_keys:
                    filter_name_action[k] = v
        return filter_name_action, resume_keys
    @classmethod
    def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> dict:
        """get the fuse or split parameter mappings of PretrainedModel

        Args:
            config (PretrainedConfig): the configuration of name-mapping

        Returns:
            dict: mapping from key tuples to fuse/split actions; empty by default
        """
        return {}
    @staticmethod
    def _resolve_prefix_keys_for_fuse_and_split(
        state_keys_base, state_keys_real, ignore_error=False, is_fuse=True
    ):
        state_keys_map = {}
        # the base keys are tuples; find the shared prefix from the real keys
        # and prepend it to every key in the tuple
        for keys in state_keys_base:
            prefix = ""
            if is_fuse:
                for x in state_keys_real:
                    for base_key in keys[:-1]:
                        if x.endswith(base_key):
                            prefix = x.replace(base_key, "")
                            break
                    if prefix != "":
                        break
            else:
                base_key = keys[-1]
                for x in state_keys_real:
                    if x.endswith(base_key):
                        prefix = x.replace(base_key, "")
                        break
            new_keys = tuple([prefix + key for key in keys])
            state_keys_map[keys] = new_keys
        return state_keys_map
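
    # Example of resolving the shared prefix for a fuse rule (made-up keys):
    #
    #     ConversionMixin._resolve_prefix_keys_for_fuse_and_split(
    #         [("q.weight", "k.weight", "qkv.weight")],
    #         ["model.q.weight", "model.k.weight"],
    #         is_fuse=True,
    #     )
    #     # -> {("q.weight", "k.weight", "qkv.weight"):
    #     #     ("model.q.weight", "model.k.weight", "model.qkv.weight")}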