qwen2.py 61 KB

  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from functools import partial
  16. from typing import List, Optional, Tuple, Union
  17. import paddle
  18. import paddle.distributed.fleet.meta_parallel as mpu
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle import Tensor
  22. from paddle.distributed import fleet
  23. from paddle.distributed.fleet.utils import sequence_parallel_utils
  24. from .....utils import logging
  25. from .....utils.env import get_device_type
  26. from ...common.vlm import fusion_ops
  27. from ...common.vlm.activations import ACT2FN
  28. from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
  29. from ...common.vlm.transformers.model_outputs import (
  30. BaseModelOutputWithPast,
  31. CausalLMOutputWithPast,
  32. )
  33. try:
  34. from paddle.incubate.nn.functional import fused_rotary_position_embedding
  35. except ImportError:
  36. fused_rotary_position_embedding = None
  37. try:
  38. from paddle.distributed.fleet.utils.sequence_parallel_utils import (
  39. GatherOp,
  40. ScatterOp,
  41. mark_as_sequence_parallel_parameter,
  42. )
  43. except ImportError:
  44. pass
  45. try:
  46. from paddle.nn.functional.flash_attention import flash_attention
  47. except ImportError:
  48. flash_attention = None
  49. Linear = nn.Linear
  50. ColumnParallelLinear = mpu.ColumnParallelLinear
  51. RowParallelLinear = mpu.RowParallelLinear
  52. ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
  53. RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
  54. class Qwen2Config(PretrainedConfig):
  55. r"""
  56. This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
  57. Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
  58. with the defaults will yield a similar configuration to that of
  59. Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
  60. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  61. documentation from [`PretrainedConfig`] for more information.
  62. Args:
  63. vocab_size (`int`, *optional*, defaults to 151936):
  64. Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
  65. `input_ids` passed when calling [`Qwen2Model`]
  66. hidden_size (`int`, *optional*, defaults to 4096):
  67. Dimension of the hidden representations.
  68. intermediate_size (`int`, *optional*, defaults to 22016):
  69. Dimension of the MLP representations.
  70. num_hidden_layers (`int`, *optional*, defaults to 32):
  71. Number of hidden layers in the Transformer encoder.
  72. num_attention_heads (`int`, *optional*, defaults to 32):
  73. Number of attention heads for each attention layer in the Transformer encoder.
  74. num_key_value_heads (`int`, *optional*, defaults to 32):
  75. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  76. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  77. `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
  78. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  79. by mean-pooling all the original heads within that group. For more details, check out [this
  80. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to `32`.
  81. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  82. The non-linear activation function (function or string) in the decoder.
  83. max_position_embeddings (`int`, *optional*, defaults to 32768):
  84. The maximum sequence length that this model might ever be used with.
  85. initializer_range (`float`, *optional*, defaults to 0.02):
  86. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  87. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  88. The epsilon used by the rms normalization layers.
  89. use_cache (`bool`, *optional*, defaults to `True`):
  90. Whether or not the model should return the last key/values attentions (not used by all models). Only
  91. relevant if `config.is_decoder=True`.
  92. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  93. Whether the model's input and output word embeddings should be tied.
  94. rope_theta (`float`, *optional*, defaults to 10000.0):
  95. The base period of the RoPE embeddings.
  96. use_sliding_window (`bool`, *optional*, defaults to `False`):
  97. Whether to use sliding window attention.
  98. sliding_window (`int`, *optional*, defaults to 4096):
  99. Sliding window attention (SWA) window size. If not specified, will default to `4096`.
  100. max_window_layers (`int`, *optional*, defaults to 28):
  101. The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  102. attention_dropout (`float`, *optional*, defaults to 0.0):
  103. The dropout ratio for the attention probabilities.
  104. """
  105. model_type = "qwen2"
  106. keys_to_ignore_at_inference = ["past_key_values"]
  107. def __init__(
  108. self,
  109. vocab_size=151936,
  110. hidden_size=4096,
  111. intermediate_size=22016,
  112. num_hidden_layers=32,
  113. num_attention_heads=32,
  114. num_key_value_heads=32,
  115. hidden_act="silu",
  116. max_position_embeddings=32768,
  117. seq_length=32768,
  118. initializer_range=0.02,
  119. rms_norm_eps=1e-6,
  120. use_cache=True,
  121. tie_word_embeddings=False,
  122. rope_theta=10000.0,
  123. pad_token_id=0,
  124. bos_token_id=151643,
  125. eos_token_id=151643,
  126. use_sliding_window=False,
  127. sliding_window=4096,
  128. max_window_layers=28,
  129. attention_dropout=0.0,
  130. rope_scaling_factor=1.0,
  131. rope_scaling_type=None,
  132. dpo_config=None,
  133. **kwargs,
  134. ):
  135. self.vocab_size = vocab_size
  136. self.max_position_embeddings = max_position_embeddings
  137. self.seq_length = seq_length
  138. self.hidden_size = hidden_size
  139. self.intermediate_size = intermediate_size
  140. self.num_hidden_layers = num_hidden_layers
  141. self.num_attention_heads = num_attention_heads
  142. self.use_sliding_window = use_sliding_window
  143. self.sliding_window = sliding_window
  144. self.max_window_layers = max_window_layers
  145. # for backward compatibility
  146. if num_key_value_heads is None:
  147. num_key_value_heads = num_attention_heads
  148. self.num_key_value_heads = num_key_value_heads
  149. self.hidden_act = hidden_act
  150. self.initializer_range = initializer_range
  151. self.rms_norm_eps = rms_norm_eps
  152. self.use_cache = use_cache
  153. self.rope_theta = rope_theta
  154. self.attention_dropout = attention_dropout
  156. self.rope_scaling_factor = rope_scaling_factor
  157. self.rope_scaling_type = rope_scaling_type
  158. self.pad_token_id = pad_token_id
  159. self.bos_token_id = bos_token_id
  160. self.eos_token_id = eos_token_id
  161. self.dpo_config = dpo_config
  162. super().__init__(
  163. pad_token_id=pad_token_id,
  164. bos_token_id=bos_token_id,
  165. eos_token_id=eos_token_id,
  166. tie_word_embeddings=tie_word_embeddings,
  167. **kwargs,
  168. )
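# A minimal usage sketch (hypothetical values, not part of this module): the config can be
# built directly and derived quantities read back; unknown keywords fall through **kwargs
# to PretrainedConfig.
#
#     cfg = Qwen2Config(hidden_size=1024, intermediate_size=2816,
#                       num_hidden_layers=4, num_attention_heads=16,
#                       num_key_value_heads=4)
#     head_dim = cfg.hidden_size // cfg.num_attention_heads            # 64
#     gqa_groups = cfg.num_attention_heads // cfg.num_key_value_heads  # 4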
  169. def get_triangle_upper_mask(x, mask=None):
  170. if mask is not None:
  171. return mask
  172. # [bsz, n_head, q_len, kv_seq_len]
  173. shape = x.shape
  174. # [bsz, 1, q_len, kv_seq_len]
  175. shape[1] = 1
  176. mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
  177. mask = paddle.triu(mask, diagonal=1)
  178. mask.stop_gradient = True
  179. return mask
  180. def parallel_matmul(
  181. x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
  182. ):
  183. is_fleet_init = True
  184. tensor_parallel_degree = 1
  185. try:
  186. hcg = fleet.get_hybrid_communicate_group()
  187. model_parallel_group = hcg.get_model_parallel_group()
  188. tensor_parallel_degree = hcg.get_model_parallel_world_size()
  189. except:
  190. is_fleet_init = False
  191. if paddle.in_dynamic_mode():
  192. y_is_distributed = y.is_distributed
  193. else:
  194. y_is_distributed = tensor_parallel_degree > 1
  195. if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
  196. # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
  197. input_parallel = paddle.distributed.collective._c_identity(
  198. x, group=model_parallel_group
  199. )
  200. logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
  201. if tensor_parallel_output:
  202. return logits
  203. return paddle.distributed.collective._c_concat(
  204. logits, group=model_parallel_group
  205. )
  206. else:
  207. logits = paddle.matmul(x, y, transpose_y=transpose_y)
  208. return logits
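# Sketch of the tensor-parallel branch above (semantics as I understand them): each rank
# holds a vocab shard of y, so matmul(x, y) produces sharded logits; _c_identity marks x as
# replicated so its gradient is all-reduced across the model-parallel group in backward,
# and _c_concat gathers the shards along the vocab axis when a full-vocabulary output is
# requested (tensor_parallel_output=False).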
  209. def scaled_dot_product_attention(
  210. query_states,
  211. config,
  212. key_states,
  213. value_states,
  214. attention_mask,
  215. output_attentions,
  216. attn_mask_startend_row_indices=None,
  217. training=True,
  218. sequence_parallel=False,
  219. skip_recompute=False,
  220. ):
  221. bsz, q_len, num_heads, head_dim = query_states.shape
  222. _, kv_seq_len, _, _ = value_states.shape
  223. # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
  224. query_states = paddle.transpose(query_states, [0, 2, 1, 3])
  225. # merge with the next transpose
  226. key_states = paddle.transpose(key_states, [0, 2, 1, 3])
  227. value_states = paddle.transpose(value_states, [0, 2, 1, 3])
  228. # Add a pre-divided factor to avoid NaN under float16.
  229. if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16:
  230. pre_divided_factor = 32
  231. else:
  232. pre_divided_factor = 1
  233. attn_weights = paddle.matmul(
  234. query_states / (math.sqrt(head_dim) * pre_divided_factor),
  235. key_states.transpose([0, 1, 3, 2]),
  236. )
  237. if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
  238. raise ValueError(
  239. f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
  240. f" {attn_weights.shape}"
  241. )
  242. if attention_mask is None:
  243. attention_mask = get_triangle_upper_mask(attn_weights)
  244. attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
  245. if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
  246. raise ValueError(
  247. f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
  248. )
  249. attn_weights = attn_weights + attention_mask
  250. if not paddle.in_dynamic_mode():
  251. attn_weights = F.softmax(
  252. attn_weights * pre_divided_factor, axis=-1, dtype="float32"
  253. ).astype(query_states.dtype)
  254. else:
  255. with paddle.amp.auto_cast(False):
  256. attn_weights = F.softmax(
  257. attn_weights.astype("float32") * pre_divided_factor,
  258. axis=-1,
  259. dtype="float32",
  260. ).astype(query_states.dtype)
  261. attn_weights = F.dropout(
  262. attn_weights, p=config.attention_dropout, training=training
  263. )
  264. attn_output = paddle.matmul(attn_weights, value_states)
  265. attn_output = attn_output.transpose([0, 2, 1, 3])
  266. if sequence_parallel:
  267. attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
  268. else:
  269. attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
  270. return (attn_output, attn_weights) if output_attentions else attn_output
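# Numerical note on the pre-divided factor: under float16 (dynamic mode) the scores are
# computed as (Q / (sqrt(head_dim) * 32)) @ K^T and the factor 32 is multiplied back inside
# the softmax, so softmax(Q @ K^T / sqrt(head_dim)) is mathematically unchanged while the
# intermediate matmul stays in a safer float16 range.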
  271. def is_casual_mask(attention_mask):
  272. """
  273. Returns True if the upper-triangular part of attention_mask equals attention_mask itself, i.e. the mask is causal.
  274. """
  275. return (paddle.triu(attention_mask) == attention_mask).all().item()
  276. def _make_causal_mask(input_ids_shape, past_key_values_length):
  277. """
  278. Make causal mask used for self-attention
  279. """
  280. batch_size, target_length = input_ids_shape # target_length: seq_len
  281. mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
  282. if past_key_values_length > 0:
  283. # [tgt_len, tgt_len + past_len]
  284. mask = paddle.concat(
  285. [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
  286. axis=-1,
  287. )
  288. # [bs, 1, tgt_len, tgt_len + past_len]
  289. return mask[None, None, :, :].expand(
  290. [batch_size, 1, target_length, target_length + past_key_values_length]
  291. )
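# Worked example (illustrative only): for input_ids_shape=(2, 3) and
# past_key_values_length=2 the returned bool mask has shape [2, 1, 3, 5]; query row i may
# attend to key positions 0..(2 + i):
#     [[1, 1, 1, 0, 0],
#      [1, 1, 1, 1, 0],
#      [1, 1, 1, 1, 1]]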
  292. def _expand_2d_mask(mask, dtype, tgt_length):
  293. """
  294. Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
  295. """
  296. batch_size, src_length = mask.shape[0], mask.shape[-1]
  297. tgt_length = tgt_length if tgt_length is not None else src_length
  298. mask = mask[:, None, None, :].astype("bool")
  299. mask.stop_gradient = True
  300. expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
  301. return expanded_mask
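# Example (illustrative): a [batch_size, src_length] padding mask is broadcast so the same
# key-validity row is repeated over every query position, giving a bool mask of shape
# [batch_size, 1, tgt_length, src_length].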
  302. class Qwen2RMSNorm(nn.Layer):
  303. def __init__(self, config: Qwen2Config):
  304. """
  305. Qwen2RMSNorm is equivalent to T5LayerNorm
  306. """
  307. super().__init__()
  308. self.hidden_size = config.hidden_size
  309. self.weight = paddle.create_parameter(
  310. shape=[self.hidden_size],
  311. dtype=paddle.get_default_dtype(),
  312. default_initializer=nn.initializer.Constant(1.0),
  313. )
  314. self.variance_epsilon = config.rms_norm_eps
  315. self.config = config
  316. if config.sequence_parallel:
  317. mark_as_sequence_parallel_parameter(self.weight)
  318. def forward(self, hidden_states):
  319. if self.config.use_fused_rms_norm:
  320. return fusion_ops.fusion_rms_norm(
  321. hidden_states, self.weight, self.variance_epsilon, False
  322. )
  323. if paddle.in_dynamic_mode():
  324. with paddle.amp.auto_cast(False):
  325. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  326. hidden_states = (
  327. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  328. )
  329. else:
  330. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  331. hidden_states = (
  332. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  333. )
  334. if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
  335. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  336. return hidden_states * self.weight
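# The eager path above is, per hidden vector x (a sketch of the same math):
#     y = x * rsqrt(mean(x.astype("float32") ** 2, axis=-1, keepdim=True) + eps) * weight
# i.e. RMS normalization with the mean-square computed in float32 and the result cast back
# to the weight dtype before the learned elementwise scale is applied.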
  337. class Qwen2RotaryEmbedding(nn.Layer):
  338. def __init__(self, dim, max_position_embeddings=2048, base=10000):
  339. super().__init__()
  340. self.dim = dim
  341. self.max_position_embeddings = max_position_embeddings
  342. self.base = base
  343. # [dim / 2]
  344. self.inv_freq = 1.0 / (
  345. self.base
  346. ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
  347. )
  348. self._set_cos_sin_cache(seq_len=max_position_embeddings)
  349. def _set_cos_sin_cache(self, seq_len):
  350. self.max_seq_len_cached = seq_len
  351. # [seq_len]
  352. t = paddle.arange(seq_len, dtype="float32")
  353. # [seq_len, dim/2]
  354. freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
  355. # Uses a different permutation than the paper, but obtains the same result.
  356. # [seq_len, dim]
  357. emb = paddle.concat([freqs, freqs], axis=-1)
  358. # [1, seqlen, 1, dim]
  359. self.cos_cached = emb.cos()[None, :, None, :]
  360. self.sin_cached = emb.sin()[None, :, None, :]
  361. def forward(self, x, seq_len=None):
  362. # x: [bs, num_attention_heads, seq_len, head_size]
  363. if seq_len > self.max_seq_len_cached:
  364. self._set_cos_sin_cache(seq_len)
  365. cos = self.cos_cached[:, :seq_len, :, :]
  366. sin = self.sin_cached[:, :seq_len, :, :]
  367. return (
  368. cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
  369. sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
  370. )
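# The cos/sin caches are stored with shape [1, max_seq_len_cached, 1, dim] so that they
# broadcast against [bs, seq_len, num_heads, head_dim] activations; the cache is rebuilt
# lazily whenever a longer seq_len than previously seen is requested.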
  371. def rotate_half(x):
  372. """Rotates half the hidden dims of the input."""
  373. x1 = x[..., : x.shape[-1] // 2]
  374. x2 = x[..., x.shape[-1] // 2 :]
  375. return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
  376. def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  377. if position_ids is None:
  378. # Note: Only for Qwen2MoEForCausalLMPipe model pretraining
  379. cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
  380. sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
  381. else:
  382. cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
  383. sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
  384. cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
  385. sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
  386. q_embed = (q * cos) + (rotate_half(q) * sin)
  387. k_embed = (k * cos) + (rotate_half(k) * sin)
  388. return q_embed, k_embed
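# RoPE sketch: with cos/sin broadcast to [bs, seq_len, 1, dim], q_embed = q * cos +
# rotate_half(q) * sin applies a 2-D rotation to every pair (q[..., i], q[..., i + dim//2])
# by the angle cached for that position and frequency; k is rotated identically, so the
# resulting q.k dot products depend only on the relative position offset.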
  389. class Qwen2MLP(nn.Layer):
  390. def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None):
  391. super().__init__()
  392. if skip_recompute_ops is None:
  393. skip_recompute_ops = {}
  394. self.skip_recompute_ops = skip_recompute_ops
  395. self.hidden_size = config.hidden_size
  396. self.intermediate_size = config.intermediate_size
  397. self.fuse_attention_ffn = config.fuse_attention_ffn
  398. self.tensor_parallel_degree = config.tensor_parallel_degree
  399. if config.sequence_parallel:
  400. ColumnParallelLinear = ColumnSequenceParallelLinear
  401. RowParallelLinear = RowSequenceParallelLinear
  402. if config.tensor_parallel_degree > 1:
  403. if self.fuse_attention_ffn:
  404. self.gate_up_fused_proj = ColumnParallelLinear(
  405. self.hidden_size,
  406. self.intermediate_size * 2,
  407. gather_output=False,
  408. has_bias=False,
  409. )
  410. else:
  411. self.gate_proj = ColumnParallelLinear(
  412. self.hidden_size,
  413. self.intermediate_size,
  414. gather_output=False,
  415. has_bias=False,
  416. )
  417. self.up_proj = ColumnParallelLinear(
  418. self.hidden_size,
  419. self.intermediate_size,
  420. gather_output=False,
  421. has_bias=False,
  422. )
  423. self.down_proj = RowParallelLinear(
  424. self.intermediate_size,
  425. self.hidden_size,
  426. input_is_parallel=True,
  427. has_bias=False,
  428. )
  429. else:
  430. if self.fuse_attention_ffn:
  431. self.gate_up_fused_proj = Linear(
  432. self.hidden_size, self.intermediate_size * 2, bias_attr=False
  433. )
  434. else:
  435. self.gate_proj = Linear(
  436. self.hidden_size, self.intermediate_size, bias_attr=False
  437. ) # w1
  438. self.up_proj = Linear(
  439. self.hidden_size, self.intermediate_size, bias_attr=False
  440. ) # w3
  441. self.down_proj = Linear(
  442. self.intermediate_size, self.hidden_size, bias_attr=False
  443. ) # w2
  444. if config.hidden_act == "silu":
  445. self.act_fn = fusion_ops.swiglu
  446. self.fuse_swiglu = True
  447. else:
  448. self.act_fn = ACT2FN[config.hidden_act]
  449. self.fuse_swiglu = False
  450. def forward(self, x):
  451. if self.fuse_attention_ffn:
  452. x = self.gate_up_fused_proj(x)
  453. if self.fuse_swiglu:
  454. y = None
  455. else:
  456. x, y = x.chunk(2, axis=-1)
  457. else:
  458. x, y = self.gate_proj(x), self.up_proj(x)
  459. if self.fuse_swiglu:
  460. x = self.act_fn(x, y)
  461. else:
  462. x = self.act_fn(x) * y
  463. return self.down_proj(x)
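# Functionally the MLP computes down_proj(silu(gate_proj(x)) * up_proj(x)). When
# fuse_attention_ffn is set, gate_proj and up_proj become one [hidden, 2 * intermediate]
# projection that is split afterwards (or consumed directly by the fused swiglu op when
# fuse_swiglu is set), which is intended to save a matmul launch without changing the result.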
  464. def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
  465. """
  466. This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=2). The hidden states go from
  467. (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim).
  468. """
  469. batch, slen, num_key_value_heads, head_dim = hidden_states.shape
  470. if n_rep == 1:
  471. return hidden_states
  472. hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
  473. return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim])
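# Shape example (illustrative): with hidden_states of shape [2, 1024, 4, 128] and n_rep=8
# (32 query heads sharing 4 key/value heads), each kv head is tiled 8 times consecutively,
# producing [2, 1024, 32, 128].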
  474. class Qwen2Attention(nn.Layer):
  475. """
  476. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  477. and "Generating Long Sequences with Sparse Transformers".
  478. """
  479. def __init__(
  480. self,
  481. config: Qwen2Config,
  482. layerwise_recompute: bool = True,
  483. skip_recompute_ops=None,
  484. ):
  485. super().__init__()
  486. if skip_recompute_ops is None:
  487. skip_recompute_ops = {}
  488. self.config = config
  489. self.skip_recompute_ops = skip_recompute_ops
  490. self.hidden_size = config.hidden_size
  491. self.num_heads = config.num_attention_heads
  492. self.head_dim = self.hidden_size // config.num_attention_heads
  493. self.num_key_value_heads = config.num_key_value_heads
  494. assert config.num_attention_heads % config.num_key_value_heads == 0
  495. self.num_key_value_groups = (
  496. config.num_attention_heads // config.num_key_value_heads
  497. )
  498. self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
  499. self.max_position_embeddings = config.max_position_embeddings
  500. self.rope_theta = config.rope_theta
  501. self.is_causal = True
  502. self.attention_dropout = config.attention_dropout
  503. self.seq_length = config.seq_length
  504. self.sequence_parallel = config.sequence_parallel
  505. self.fuse_attention_qkv = config.fuse_attention_qkv
  506. # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
  507. # enable_recompute defaults to False and is controlled by the Trainer
  508. self.enable_recompute = False
  509. self.layerwise_recompute = layerwise_recompute
  510. self.recompute_granularity = config.recompute_granularity
  511. if config.tensor_parallel_degree > 1:
  512. assert (
  513. self.num_heads % config.tensor_parallel_degree == 0
  514. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  515. self.num_heads = self.num_heads // config.tensor_parallel_degree
  516. assert (
  517. self.num_key_value_heads % config.tensor_parallel_degree == 0
  518. ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  519. self.num_key_value_heads = (
  520. self.num_key_value_heads // config.tensor_parallel_degree
  521. )
  522. self.use_fused_rope = config.use_fused_rope
  523. if self.use_fused_rope:
  524. if (
  525. get_device_type() not in ["gpu", "xpu"]
  526. or fused_rotary_position_embedding is None
  527. ):
  528. logging.warning(
  529. "Enable fuse rope in the config, but fuse rope is not available. "
  530. "Will disable fuse rope. Try using latest gpu version of Paddle."
  531. )
  532. self.use_fused_rope = False
  533. if config.sequence_parallel:
  534. ColumnParallelLinear = ColumnSequenceParallelLinear
  535. RowParallelLinear = RowSequenceParallelLinear
  536. if config.tensor_parallel_degree > 1:
  537. if self.fuse_attention_qkv:
  538. self.qkv_proj = ColumnParallelLinear(
  539. self.hidden_size,
  540. self.hidden_size
  541. + 2 * self.config.num_key_value_heads * self.head_dim,
  542. has_bias=True,
  543. gather_output=False,
  544. )
  545. else:
  546. self.q_proj = ColumnParallelLinear(
  547. self.hidden_size,
  548. self.hidden_size,
  549. has_bias=True,
  550. gather_output=False,
  551. )
  552. self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  553. self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  554. self.o_proj = RowParallelLinear(
  555. self.hidden_size,
  556. self.hidden_size,
  557. has_bias=False,
  558. input_is_parallel=True,
  559. )
  560. else:
  561. if self.fuse_attention_qkv:
  562. self.qkv_proj = Linear(
  563. self.hidden_size,
  564. self.hidden_size
  565. + 2 * self.config.num_key_value_heads * self.head_dim,
  566. )
  567. else:
  568. self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
  569. self.k_proj = Linear(
  570. self.hidden_size,
  571. self.config.num_key_value_heads * self.head_dim,
  572. bias_attr=True,
  573. )
  574. self.v_proj = Linear(
  575. self.hidden_size,
  576. self.config.num_key_value_heads * self.head_dim,
  577. bias_attr=True,
  578. )
  579. self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
  580. self.rotary_emb = Qwen2RotaryEmbedding(
  581. self.head_dim,
  582. max_position_embeddings=self.max_position_embeddings,
  583. base=self.rope_theta,
  584. )
  585. self.attn_func = scaled_dot_product_attention
  586. def forward(
  587. self,
  588. hidden_states,
  589. position_ids: Optional[Tuple[paddle.Tensor]] = None,
  590. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  591. attention_mask: Optional[paddle.Tensor] = None,
  592. output_attentions: bool = False,
  593. use_cache: bool = False,
  594. attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
  595. **kwargs,
  596. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  597. """Input shape: Batch x Time x Channel"""
  598. # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
  599. if self.fuse_attention_qkv:
  600. mix_layer = self.qkv_proj(hidden_states)
  601. if self.sequence_parallel:
  602. target_shape = [
  603. -1,
  604. self.seq_length,
  605. self.num_key_value_heads,
  606. (self.num_key_value_groups + 2) * self.head_dim,
  607. ]
  608. else:
  609. target_shape = [
  610. 0,
  611. 0,
  612. self.num_key_value_heads,
  613. (self.num_key_value_groups + 2) * self.head_dim,
  614. ]
  615. mix_layer = paddle.reshape_(mix_layer, target_shape)
  616. query_states, key_states, value_states = paddle.split(
  617. mix_layer,
  618. num_or_sections=[
  619. self.num_key_value_groups * self.head_dim,
  620. self.head_dim,
  621. self.head_dim,
  622. ],
  623. axis=-1,
  624. )
  625. if self.gqa_or_mqa:
  626. query_states = paddle.reshape_(
  627. query_states, [0, 0, self.num_heads, self.head_dim]
  628. )
  629. else:
  630. query_states = self.q_proj(hidden_states)
  631. key_states = self.k_proj(hidden_states)
  632. value_states = self.v_proj(hidden_states)
  633. if self.sequence_parallel:
  634. target_query_shape = [
  635. -1,
  636. self.seq_length,
  637. self.num_heads,
  638. self.head_dim,
  639. ]
  640. target_key_value_shape = [
  641. -1,
  642. self.seq_length,
  643. self.num_key_value_heads,
  644. self.head_dim,
  645. ]
  646. else:
  647. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  648. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  649. query_states = query_states.reshape(shape=target_query_shape)
  650. key_states = key_states.reshape(shape=target_key_value_shape)
  651. value_states = value_states.reshape(shape=target_key_value_shape)
  652. kv_seq_len = key_states.shape[-3]
  653. if past_key_value is not None:
  654. kv_seq_len += past_key_value[0].shape[-3]
  655. if self.use_fused_rope:
  656. assert past_key_value is None, "fused rotary does not support cache kv for now"
  657. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  658. query_states, key_states, _ = fused_rotary_position_embedding(
  659. query_states,
  660. key_states,
  661. v=None,
  662. sin=sin,
  663. cos=cos,
  664. position_ids=position_ids,
  665. use_neox_rotary_style=False,
  666. )
  667. else:
  668. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  669. query_states, key_states = apply_rotary_pos_emb(
  670. query_states, key_states, cos, sin, position_ids
  671. )
  672. # [bs, seq_len, num_head, head_dim]
  673. if past_key_value is not None:
  674. key_states = paddle.concat([past_key_value[0], key_states], axis=1)
  675. value_states = paddle.concat([past_key_value[1], value_states], axis=1)
  676. past_key_value = (key_states, value_states) if use_cache else None
  677. # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
  678. # repeat k/v heads if n_kv_heads < n_heads
  679. paddle_version = float(paddle.__version__[:3])
  680. if not self.config.use_flash_attention or (
  681. (paddle_version != 0.0) and (paddle_version <= 2.6)
  682. ):
  683. key_states = repeat_kv(key_states, self.num_key_value_groups)
  684. value_states = repeat_kv(value_states, self.num_key_value_groups)
  685. outputs = self.attn_func(
  686. query_states,
  687. self.config,
  688. key_states,
  689. value_states,
  690. attention_mask,
  691. output_attentions,
  692. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  693. training=self.training,
  694. sequence_parallel=self.sequence_parallel,
  695. )
  696. if output_attentions:
  697. attn_output, attn_weights = outputs
  698. else:
  699. attn_output = outputs
  700. # If sequence_parallel is true, the output shape is [q_len / n, bs, num_head * head_dim];
  701. # otherwise it is [bs, q_len, num_head * head_dim], where n is the mp parallelism degree.
  702. attn_output = self.o_proj(attn_output)
  703. if not output_attentions:
  704. attn_weights = None
  705. outputs = (attn_output,)
  706. if output_attentions:
  707. outputs += (attn_weights,)
  708. if use_cache:
  709. outputs += (past_key_value,)
  710. if type(outputs) is tuple and len(outputs) == 1:
  711. outputs = outputs[0]
  712. return outputs
  713. class Qwen2DecoderLayer(nn.Layer):
  714. def __init__(
  715. self,
  716. config: Qwen2Config,
  717. layerwise_recompute: bool = False,
  718. skip_recompute_ops=None,
  719. ):
  720. super().__init__()
  721. if skip_recompute_ops is None:
  722. skip_recompute_ops = {}
  723. self.config = config
  724. self.skip_recompute_ops = skip_recompute_ops
  725. self.hidden_size = config.hidden_size
  726. self.self_attn = Qwen2Attention(
  727. config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops
  728. )
  729. self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops)
  730. self.input_layernorm = Qwen2RMSNorm(config)
  731. self.post_attention_layernorm = Qwen2RMSNorm(config)
  732. # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
  733. # enable_recompute defaults to False and is controlled by the Trainer
  734. self.enable_recompute = False
  735. self.layerwise_recompute = layerwise_recompute
  736. self.recompute_granularity = config.recompute_granularity
  737. def forward(
  738. self,
  739. hidden_states: paddle.Tensor,
  740. position_ids: Optional[paddle.Tensor] = None,
  741. attention_mask: Optional[paddle.Tensor] = None,
  742. output_attentions: Optional[bool] = False,
  743. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  744. use_cache: Optional[bool] = False,
  745. attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
  746. **kwargs,
  747. ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
  748. """
  749. Args:
  750. hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  751. attention_mask (`paddle.Tensor`, *optional*): attention mask of size
  752. `(batch, sequence_length)` where padding elements are indicated by 0.
  753. output_attentions (`bool`, *optional*):
  754. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  755. returned tensors for more detail.
  756. use_cache (`bool`, *optional*):
  757. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  758. (see `past_key_values`).
  759. past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
  760. """
  761. # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
  762. residual = hidden_states
  763. hidden_states = self.input_layernorm(hidden_states)
  764. # Self Attention
  765. outputs = self.self_attn(
  766. hidden_states,
  767. position_ids,
  768. past_key_value,
  769. attention_mask,
  770. output_attentions,
  771. use_cache,
  772. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  773. )
  774. if type(outputs) is tuple:
  775. hidden_states = outputs[0]
  776. else:
  777. hidden_states = outputs
  778. if output_attentions:
  779. self_attn_weights = outputs[1]
  780. if use_cache:
  781. present_key_value = outputs[2 if output_attentions else 1]
  782. hidden_states = residual + hidden_states
  783. # Fully Connected
  784. residual = hidden_states
  785. hidden_states = self.post_attention_layernorm(hidden_states)
  786. hidden_states = self.mlp(hidden_states)
  787. hidden_states = residual + hidden_states
  788. outputs = (hidden_states,)
  789. if output_attentions:
  790. outputs += (self_attn_weights,)
  791. if use_cache:
  792. outputs += (present_key_value,)
  793. if type(outputs) is tuple and len(outputs) == 1:
  794. outputs = outputs[0]
  795. return outputs
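# Per decoder layer the computation is the usual pre-norm residual pattern:
#     h = x + SelfAttn(RMSNorm(x)); out = h + MLP(RMSNorm(h))
# with attention weights and the present key/value cache appended to the outputs tuple
# only when output_attentions / use_cache are requested.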
  796. class Qwen2PretrainedModel(PretrainedModel):
  797. config_class = Qwen2Config
  798. base_model_prefix = "qwen2"
  799. _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
  800. @classmethod
  801. def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
  802. # return parameter fuse utils
  803. from ...common.vlm.conversion_utils import split_or_fuse_func
  804. fn = split_or_fuse_func(is_fuse=is_fuse)
  805. # last key is fused key, other keys are to be fused.
  806. fuse_qkv_keys = [
  807. (
  808. "layers.0.self_attn.q_proj.weight",
  809. "layers.0.self_attn.k_proj.weight",
  810. "layers.0.self_attn.v_proj.weight",
  811. "layers.0.self_attn.qkv_proj.weight",
  812. ),
  813. (
  814. "layers.0.self_attn.q_proj.bias",
  815. "layers.0.self_attn.k_proj.bias",
  816. "layers.0.self_attn.v_proj.bias",
  817. "layers.0.self_attn.qkv_proj.bias",
  818. ),
  819. ]
  820. fuse_gate_up_keys = (
  821. "layers.0.mlp.gate_proj.weight",
  822. "layers.0.mlp.up_proj.weight",
  823. "layers.0.mlp.gate_up_fused_proj.weight",
  824. )
  825. num_heads = config.num_attention_heads
  826. num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
  827. fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
  828. fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
  829. final_actions = {}
  830. if is_fuse:
  831. if fuse_attention_qkv:
  832. for i in range(config.num_hidden_layers):
  833. for fuse_keys in fuse_qkv_keys:
  834. keys = tuple(
  835. [
  836. key.replace("layers.0.", f"layers.{i}.")
  837. for key in fuse_keys
  838. ]
  839. )
  840. final_actions[keys] = partial(
  841. fn,
  842. is_qkv=True,
  843. num_heads=num_heads,
  844. num_key_value_heads=num_key_value_heads,
  845. )
  846. if fuse_attention_ffn:
  847. for i in range(config.num_hidden_layers):
  848. keys = tuple(
  849. [
  850. key.replace("layers.0.", f"layers.{i}.")
  851. for key in fuse_gate_up_keys
  852. ]
  853. )
  854. final_actions[keys] = fn
  855. else:
  856. if not fuse_attention_qkv:
  857. for i in range(config.num_hidden_layers):
  858. for fuse_keys in fuse_qkv_keys:
  859. keys = tuple(
  860. [
  861. key.replace("layers.0.", f"layers.{i}.")
  862. for key in fuse_keys
  863. ]
  864. )
  865. final_actions[keys] = partial(
  866. fn,
  867. split_nums=3,
  868. is_qkv=True,
  869. num_heads=num_heads,
  870. num_key_value_heads=num_key_value_heads,
  871. )
  872. if not fuse_attention_ffn:
  873. for i in range(config.num_hidden_layers):
  874. keys = tuple(
  875. [
  876. key.replace("layers.0.", f"layers.{i}.")
  877. for key in fuse_gate_up_keys
  878. ]
  879. )
  880. final_actions[keys] = partial(fn, split_nums=2)
  881. return final_actions
  882. class Qwen2Model(Qwen2PretrainedModel):
  883. """
  884. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
  885. Args:
  886. config: Qwen2Config
  887. """
  888. def __init__(self, config: Qwen2Config):
  889. super().__init__(config)
  890. self.padding_idx = config.pad_token_id
  891. self.vocab_size = config.vocab_size
  892. self.hidden_size = config.hidden_size
  893. self.sequence_parallel = config.sequence_parallel
  894. self.recompute_granularity = config.recompute_granularity
  895. self.no_recompute_layers = (
  896. config.no_recompute_layers if config.no_recompute_layers is not None else []
  897. )
  898. # Recompute defaults to False and is controlled by Trainer
  899. self.enable_recompute = False
  900. if (
  901. config.tensor_parallel_degree > 1
  902. and config.vocab_size % config.tensor_parallel_degree == 0
  903. ):
  904. self.embed_tokens = mpu.VocabParallelEmbedding(
  905. self.vocab_size,
  906. self.hidden_size,
  907. weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
  908. )
  909. else:
  910. self.embed_tokens = nn.Embedding(
  911. self.vocab_size,
  912. self.hidden_size,
  913. )
  914. self.layers = nn.LayerList(
  915. [
  916. Qwen2DecoderLayer(
  917. config=config,
  918. layerwise_recompute=layer_idx not in self.no_recompute_layers,
  919. )
  920. for layer_idx in range(config.num_hidden_layers)
  921. ]
  922. )
  923. self.norm = Qwen2RMSNorm(config)
  924. def get_input_embeddings(self):
  925. return self.embed_tokens
  926. def set_input_embeddings(self, value):
  927. self.embed_tokens = value
  928. @staticmethod
  929. def _prepare_decoder_attention_mask(
  930. attention_mask, input_shape, past_key_values_length, dtype
  931. ):
  932. if attention_mask is not None:
  933. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  934. if len(attention_mask.shape) == 2:
  935. expanded_attn_mask = _expand_2d_mask(
  936. attention_mask, dtype, tgt_length=input_shape[-1]
  937. )
  938. # For the decoding phase in generation (seq_length == 1), we don't need to add a causal mask
  939. if input_shape[-1] > 1:
  940. combined_attention_mask = _make_causal_mask(
  941. input_shape,
  942. past_key_values_length=past_key_values_length,
  943. )
  944. expanded_attn_mask = expanded_attn_mask & combined_attention_mask
  945. # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
  946. elif len(attention_mask.shape) == 3:
  947. expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
  948. # if attention_mask is already 4-D, do nothing
  949. else:
  950. expanded_attn_mask = attention_mask
  951. else:
  952. expanded_attn_mask = _make_causal_mask(
  953. input_shape,
  954. past_key_values_length=past_key_values_length,
  955. )
  956. # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
  957. if get_device_type() == "xpu":
  958. x = paddle.to_tensor(0.0, dtype="float32")
  959. y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
  960. expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
  961. else:
  962. expanded_attn_mask = paddle.where(
  963. expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min
  964. ).astype(dtype)
  965. return expanded_attn_mask
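# The bool mask is converted to additive form at the end: allowed positions become 0.0 and
# masked positions become the dtype's most negative finite value (a fixed large negative
# constant on XPU), so it can simply be added to the attention scores before softmax.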
  966. def forward(
  967. self,
  968. input_ids: paddle.Tensor = None,
  969. position_ids: Optional[paddle.Tensor] = None,
  970. attention_mask: Optional[paddle.Tensor] = None,
  971. inputs_embeds: Optional[paddle.Tensor] = None,
  972. use_cache: Optional[bool] = None,
  973. past_key_values: Optional[List[paddle.Tensor]] = None,
  974. output_attentions: Optional[bool] = None,
  975. output_hidden_states: Optional[bool] = None,
  976. return_dict: Optional[bool] = None,
  977. attn_mask_startend_row_indices=None,
  978. ) -> Union[Tuple, BaseModelOutputWithPast]:
  979. output_attentions = (
  980. output_attentions
  981. if output_attentions is not None
  982. else self.config.output_attentions
  983. )
  984. output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
  985. use_cache = use_cache if use_cache is not None else self.config.use_cache
  986. return_dict = (
  987. return_dict if return_dict is not None else self.config.use_return_dict
  988. )
  989. # retrieve input_ids and inputs_embeds
  990. if input_ids is not None and inputs_embeds is not None:
  991. raise ValueError(
  992. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  993. )
  994. elif input_ids is not None:
  995. batch_size, seq_length = input_ids.shape
  996. elif inputs_embeds is not None:
  997. batch_size, seq_length, _ = inputs_embeds.shape
  998. else:
  999. raise ValueError(
  1000. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1001. )
  1002. if past_key_values is None:
  1003. past_key_values = tuple([None] * len(self.layers))
  1004. # NOTE: convert to a list so the cache can be cleared in time
  1005. past_key_values = list(past_key_values)
  1006. seq_length_with_past = seq_length
  1007. cache_length = 0
  1008. if past_key_values[0] is not None:
  1009. cache_length = past_key_values[0][0].shape[1]
  1010. seq_length_with_past += cache_length
  1011. if inputs_embeds is None:
  1012. # [bs, seq_len, dim]
  1013. inputs_embeds = self.embed_tokens(input_ids)
  1014. if self.sequence_parallel:
  1015. # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
  1016. bs, seq_len, hidden_size = inputs_embeds.shape
  1017. inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size])
  1018. # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
  1019. inputs_embeds = ScatterOp.apply(inputs_embeds)
  1020. # [bs, seq_len]
  1021. attention_mask = (
  1022. paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
  1023. if attention_mask is None
  1024. else attention_mask
  1025. )
  1026. attention_mask = self._prepare_decoder_attention_mask(
  1027. attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
  1028. ) # [bs, 1, seq_len, seq_len]
  1029. if self.config.use_flash_attention:
  1030. attention_mask = None if is_casual_mask(attention_mask) else attention_mask
  1031. if position_ids is None:
  1032. position_ids = paddle.arange(seq_length, dtype="int64").expand(
  1033. (batch_size, seq_length)
  1034. )
  1035. hidden_states = inputs_embeds
  1036. # decoder layers
  1037. all_hidden_states = () if output_hidden_states else None
  1038. all_self_attns = () if output_attentions else None
  1039. next_decoder_cache = () if use_cache else None
  1040. for idx, (decoder_layer) in enumerate(self.layers):
  1041. if output_hidden_states:
  1042. all_hidden_states += (hidden_states,)
  1043. past_key_value = (
  1044. past_key_values[idx] if past_key_values is not None else None
  1045. )
  1046. has_gradient = not hidden_states.stop_gradient
  1047. if (
  1048. self.enable_recompute
  1049. and idx not in self.no_recompute_layers
  1050. and has_gradient
  1051. and self.recompute_granularity == "full"
  1052. ):
  1053. layer_outputs = self.recompute_training_full(
  1054. decoder_layer,
  1055. hidden_states,
  1056. position_ids,
  1057. attention_mask,
  1058. output_attentions,
  1059. past_key_value,
  1060. use_cache,
  1061. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  1062. )
  1063. else:
  1064. layer_outputs = decoder_layer(
  1065. hidden_states,
  1066. position_ids,
  1067. attention_mask,
  1068. output_attentions,
  1069. past_key_value,
  1070. use_cache,
  1071. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  1072. )
  1073. # NOTE: clear the outdated cache after it has been used, to save memory
  1074. past_key_value = past_key_values[idx] = None
  1075. if type(layer_outputs) is tuple:
  1076. hidden_states = layer_outputs[0]
  1077. else:
  1078. hidden_states = layer_outputs
  1079. if output_attentions:
  1080. all_self_attns += (layer_outputs[1],)
  1081. if use_cache:
  1082. next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
  1083. hidden_states = self.norm(hidden_states)
  1084. # add hidden states from the last decoder layer
  1085. if output_hidden_states:
  1086. all_hidden_states += (hidden_states,)
  1087. next_cache = next_decoder_cache if use_cache else None
  1088. if not return_dict:
  1089. return tuple(
  1090. v
  1091. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
  1092. if v is not None
  1093. )
  1094. return BaseModelOutputWithPast(
  1095. last_hidden_state=hidden_states,
  1096. past_key_values=next_cache,
  1097. hidden_states=all_hidden_states,
  1098. attentions=all_self_attns,
  1099. )
  1100. class Qwen2PretrainingCriterion(nn.Layer):
  1101. """
  1102. Criterion for Qwen2 pretraining.
  1103. It calculates the final loss.
  1104. """
  1105. def __init__(self, config: Qwen2Config):
  1106. super(Qwen2PretrainingCriterion, self).__init__()
  1107. self.ignore_index = getattr(config, "ignore_index", -100)
  1108. self.config = config
  1109. self.enable_parallel_cross_entropy = (
  1110. config.tensor_parallel_degree > 1 and config.tensor_parallel_output
  1111. )
  1112. if (
  1113. self.enable_parallel_cross_entropy
  1114. ):  # lm_head is distributed
  1115. self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
  1116. else:
  1117. self.loss_func = paddle.nn.CrossEntropyLoss(
  1118. reduction="none", ignore_index=self.ignore_index
  1119. )
  1120. def forward(self, prediction_scores, masked_lm_labels):
  1121. if self.enable_parallel_cross_entropy:
  1122. if prediction_scores.shape[-1] == self.config.vocab_size:
  1123. logging.warning(
  1124. f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}"
  1125. )
  1126. self.loss_func = paddle.nn.CrossEntropyLoss(
  1127. reduction="none", ignore_index=self.ignore_index
  1128. )
  1129. with paddle.amp.auto_cast(False):
  1130. masked_lm_loss = self.loss_func(
  1131. prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
  1132. )
  1133. # skip positions with ignore_index, whose loss == 0
  1134. # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
  1135. # loss = paddle.mean(masked_lm_loss)
  1136. binary_sequence = paddle.where(
  1137. masked_lm_loss > 0,
  1138. paddle.ones_like(masked_lm_loss),
  1139. paddle.zeros_like(masked_lm_loss),
  1140. )
  1141. count = paddle.sum(binary_sequence)
  1142. if count == 0:
  1143. loss = paddle.sum(masked_lm_loss * binary_sequence)
  1144. else:
  1145. loss = paddle.sum(masked_lm_loss * binary_sequence) / count
  1146. return loss
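# In effect the criterion is a masked token-mean cross entropy computed in float32:
#     loss = sum(per_token_loss * (per_token_loss > 0)) / max(#nonzero, 1)
# and ParallelCrossEntropy is only used when the logits are still split along the vocab
# axis (tensor_parallel_degree > 1 with tensor_parallel_output).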
  1147. class Qwen2LMHead(nn.Layer):
  1148. def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
  1149. super(Qwen2LMHead, self).__init__()
  1150. self.config = config
  1151. if (
  1152. config.tensor_parallel_degree > 1
  1153. and config.vocab_size % config.tensor_parallel_degree == 0
  1154. ):
  1155. vocab_size = config.vocab_size // config.tensor_parallel_degree
  1156. else:
  1157. vocab_size = config.vocab_size
  1158. self.transpose_y = transpose_y
  1159. if transpose_y:
  1160. if embedding_weights is not None:
  1161. self.weight = embedding_weights
  1162. else:
  1163. self.weight = self.create_parameter(
  1164. shape=[vocab_size, config.hidden_size],
  1165. dtype=paddle.get_default_dtype(),
  1166. )
  1167. else:
  1168. # both the tensor-parallel and full-vocab cases share the same [hidden_size, vocab_size] shape
  1169. self.weight = self.create_parameter(
  1170. shape=[config.hidden_size, vocab_size],
  1171. dtype=paddle.get_default_dtype(),
  1172. )
  1178. # Must set distributed attr for Tensor Parallel !
  1179. self.weight.is_distributed = (
  1180. True if (vocab_size != config.vocab_size) else False
  1181. )
  1182. if self.weight.is_distributed:
  1183. # for tie_word_embeddings
  1184. self.weight.split_axis = 0 if self.transpose_y else 1
  1185. def forward(self, hidden_states, tensor_parallel_output=None):
  1186. if self.config.sequence_parallel:
  1187. hidden_states = GatherOp.apply(hidden_states)
  1188. seq_length = self.config.seq_length
  1189. hidden_states = paddle.reshape_(
  1190. hidden_states, [-1, seq_length, self.config.hidden_size]
  1191. )
  1192. if tensor_parallel_output is None:
  1193. tensor_parallel_output = self.config.tensor_parallel_output
  1194. logits = parallel_matmul(
  1195. hidden_states,
  1196. self.weight,
  1197. transpose_y=self.transpose_y,
  1198. tensor_parallel_output=tensor_parallel_output,
  1199. )
  1200. return logits
  1201. class Qwen2ForCausalLM(Qwen2PretrainedModel):
  1202. enable_to_static_method = True
  1203. _tied_weights_keys = ["lm_head.weight"]
  1204. def __init__(self, config: Qwen2Config):
  1205. super().__init__(config)
  1206. self.qwen2 = Qwen2Model(config)
  1207. if config.tie_word_embeddings:
  1208. self.lm_head = Qwen2LMHead(
  1209. config,
  1210. embedding_weights=self.qwen2.embed_tokens.weight,
  1211. transpose_y=True,
  1212. )
  1213. self.tie_weights()
  1214. else:
  1215. self.lm_head = Qwen2LMHead(config)
  1216. self.criterion = Qwen2PretrainingCriterion(config)
  1217. self.vocab_size = config.vocab_size
  1218. def get_input_embeddings(self):
  1219. return self.qwen2.embed_tokens
  1220. def set_input_embeddings(self, value):
  1221. self.qwen2.embed_tokens = value
  1222. def get_output_embeddings(self):
  1223. return self.lm_head
  1224. def set_output_embeddings(self, new_embeddings):
  1225. self.lm_head = new_embeddings
  1226. def set_decoder(self, decoder):
  1227. self.qwen2 = decoder
  1228. def get_decoder(self):
  1229. return self.qwen2
  1230. def prepare_inputs_for_generation(
  1231. self,
  1232. input_ids,
  1233. use_cache=False,
  1234. past_key_values=None,
  1235. attention_mask=None,
  1236. inputs_embeds=None,
  1237. **kwargs,
  1238. ):
  1239. batch_size, seq_length = input_ids.shape
  1240. position_ids = kwargs.get(
  1241. "position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))
  1242. )
  1243. if past_key_values:
  1244. input_ids = input_ids[:, -1].unsqueeze(axis=-1)
  1245. position_ids = position_ids[:, -1].unsqueeze(-1)
  1246. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  1247. if inputs_embeds is not None and past_key_values is None:
  1248. model_inputs = {"inputs_embeds": inputs_embeds}
  1249. else:
  1250. model_inputs = {"input_ids": input_ids}
  1251. model_inputs.update(
  1252. {
  1253. "position_ids": position_ids,
  1254. "past_key_values": past_key_values,
  1255. "use_cache": use_cache,
  1256. "attention_mask": attention_mask,
  1257. }
  1258. )
  1259. return model_inputs
  1260. def _get_model_inputs_spec(self, dtype: str):
  1261. return {
  1262. "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
  1263. "attention_mask": paddle.static.InputSpec(
  1264. shape=[None, None], dtype="int64"
  1265. ),
  1266. "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
  1267. }
  1268. @staticmethod
  1269. def update_model_kwargs_for_generation(
  1270. outputs, model_kwargs, is_encoder_decoder=False
  1271. ):
  1272. # update cache
  1273. if (
  1274. isinstance(outputs, tuple)
  1275. and len(outputs) > 1
  1276. and not isinstance(outputs[1], paddle.Tensor)
  1277. ):
  1278. model_kwargs["past_key_values"] = outputs[1]
  1279. if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
  1280. model_kwargs["past_key_values"] = outputs.past_key_values
  1281. # update position_ids
  1282. if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
  1283. position_ids = model_kwargs["position_ids"]
  1284. model_kwargs["position_ids"] = paddle.concat(
  1285. [position_ids, position_ids[..., -1:] + 1], axis=-1
  1286. )
  1287. if not is_encoder_decoder and "attention_mask" in model_kwargs:
  1288. # TODO: support attention mask for other models
  1289. attention_mask = model_kwargs["attention_mask"]
  1290. if len(attention_mask.shape) == 2:
  1291. model_kwargs["attention_mask"] = paddle.concat(
  1292. [
  1293. attention_mask,
  1294. paddle.ones(
  1295. [attention_mask.shape[0], 1], dtype=attention_mask.dtype
  1296. ),
  1297. ],
  1298. axis=-1,
  1299. )
  1300. elif len(attention_mask.shape) == 4:
  1301. model_kwargs["attention_mask"] = paddle.concat(
  1302. [
  1303. attention_mask,
  1304. paddle.ones(
  1305. [*attention_mask.shape[:3], 1], dtype=attention_mask.dtype
  1306. ),
  1307. ],
  1308. axis=-1,
  1309. )[:, :, -1:, :]
  1310. return model_kwargs
  1311. def forward(
  1312. self,
  1313. input_ids: paddle.Tensor = None,
  1314. position_ids: Optional[paddle.Tensor] = None,
  1315. attention_mask: Optional[paddle.Tensor] = None,
  1316. inputs_embeds: Optional[paddle.Tensor] = None,
  1317. labels: Optional[paddle.Tensor] = None,
  1318. use_cache: Optional[bool] = None,
  1319. past_key_values: Optional[List[paddle.Tensor]] = None,
  1320. output_attentions: Optional[bool] = None,
  1321. output_hidden_states: Optional[bool] = None,
  1322. return_dict: Optional[bool] = None,
  1323. attn_mask_startend_row_indices=None,
  1324. ) -> Union[Tuple, CausalLMOutputWithPast]:
  1325. r"""
  1326. Args:
  1327. labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1328. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1329. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1330. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1331. Returns:
  1332. Example:
  1333. ```python
  1334. >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
  1335. >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
  1336. >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
  1337. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  1338. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1339. >>> # Generate
  1340. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1341. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1342. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  1343. ```"""
  1344. output_attentions = (
  1345. output_attentions
  1346. if output_attentions is not None
  1347. else self.config.output_attentions
  1348. )
  1349. output_hidden_states = (
  1350. output_hidden_states
  1351. if output_hidden_states is not None
  1352. else self.config.output_hidden_states
  1353. )
  1354. return_dict = (
  1355. return_dict if return_dict is not None else self.config.use_return_dict
  1356. )
  1357. if attn_mask_startend_row_indices is not None and attention_mask is not None:
  1358. logging.warning(
  1359. "You have provided both attn_mask_startend_row_indices and attention_mask. "
  1360. "The attn_mask_startend_row_indices will be used."
  1361. )
  1362. attention_mask = None
  1363. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1364. outputs = self.qwen2(
  1365. input_ids=input_ids,
  1366. position_ids=position_ids,
  1367. attention_mask=attention_mask,
  1368. inputs_embeds=inputs_embeds,
  1369. use_cache=use_cache,
  1370. past_key_values=past_key_values,
  1371. output_attentions=output_attentions,
  1372. output_hidden_states=output_hidden_states,
  1373. return_dict=return_dict,
  1374. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  1375. )
  1376. hidden_states = outputs[0]
  1377. # if labels is None, we need the full output instead of the tensor_parallel_output
  1378. # tensor_parallel_output is used together with ParallelCrossEntropy
  1379. tensor_parallel_output = (
  1380. self.config.tensor_parallel_output
  1381. and self.config.tensor_parallel_degree > 1
  1382. )
  1383. logits = self.lm_head(
  1384. hidden_states, tensor_parallel_output=tensor_parallel_output
  1385. )
  1386. loss = None
  1387. if not return_dict:
  1388. output = (logits,) + outputs[1:]
  1389. return (loss,) + output if loss is not None else output
  1390. return CausalLMOutputWithPast(
  1391. loss=loss,
  1392. logits=logits,
  1393. past_key_values=outputs.past_key_values,
  1394. hidden_states=outputs.hidden_states,
  1395. attentions=outputs.attentions,
  1396. )