qwen2.py

  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from functools import partial
  16. from typing import List, Optional, Tuple, Union
  17. import paddle
  18. import paddle.distributed.fleet.meta_parallel as mpu
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle import Tensor
  22. from paddle.distributed import fleet
  23. from paddle.distributed.fleet.utils import sequence_parallel_utils
  24. from .....utils import logging
  25. from .....utils.env import get_device_type
  26. from ...common.vlm import fusion_ops
  27. from ...common.vlm.activations import ACT2FN
  28. from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
  29. from ...common.vlm.transformers.model_outputs import (
  30. BaseModelOutputWithPast,
  31. CausalLMOutputWithPast,
  32. )
  33. try:
  34. from paddle.incubate.nn.functional import fused_rotary_position_embedding
  35. except ImportError:
  36. fused_rotary_position_embedding = None
  37. try:
  38. from paddle.distributed.fleet.utils.sequence_parallel_utils import (
  39. GatherOp,
  40. ScatterOp,
  41. mark_as_sequence_parallel_parameter,
  42. )
  43. except ImportError:
  44. pass
  45. try:
  46. from paddle.nn.functional.flash_attention import flash_attention
  47. except ImportError:
  48. flash_attention = None
  49. Linear = nn.Linear
  50. ColumnParallelLinear = mpu.ColumnParallelLinear
  51. RowParallelLinear = mpu.RowParallelLinear
  52. ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
  53. RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
  54. class Qwen2Config(PretrainedConfig):
  55. r"""
  56. This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
  57. Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
  58. with the defaults will yield a similar configuration to that of
  59. Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
  60. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  61. documentation from [`PretrainedConfig`] for more information.
  62. Args:
  63. vocab_size (`int`, *optional*, defaults to 151936):
  64. Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
  65. `inputs_ids` passed when calling [`Qwen2Model`]
  66. hidden_size (`int`, *optional*, defaults to 4096):
  67. Dimension of the hidden representations.
  68. intermediate_size (`int`, *optional*, defaults to 22016):
  69. Dimension of the MLP representations.
  70. num_hidden_layers (`int`, *optional*, defaults to 32):
  71. Number of hidden layers in the Transformer encoder.
  72. num_attention_heads (`int`, *optional*, defaults to 32):
  73. Number of attention heads for each attention layer in the Transformer encoder.
  74. num_key_value_heads (`int`, *optional*, defaults to 32):
  75. This is the number of key/value heads that should be used to implement Grouped Query Attention. If
  76. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
  77. `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
  78. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  79. by mean-pooling all the original heads within that group. For more details, check out [this
  80. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to `32`.
  81. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  82. The non-linear activation function (function or string) in the decoder.
  83. max_position_embeddings (`int`, *optional*, defaults to 32768):
  84. The maximum sequence length that this model might ever be used with.
  85. initializer_range (`float`, *optional*, defaults to 0.02):
  86. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  87. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  88. The epsilon used by the rms normalization layers.
  89. use_cache (`bool`, *optional*, defaults to `True`):
  90. Whether or not the model should return the last key/values attentions (not used by all models). Only
  91. relevant if `config.is_decoder=True`.
  92. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  93. Whether the model's input and output word embeddings should be tied.
  94. rope_theta (`float`, *optional*, defaults to 10000.0):
  95. The base period of the RoPE embeddings.
  96. use_sliding_window (`bool`, *optional*, defaults to `False`):
  97. Whether to use sliding window attention.
  98. sliding_window (`int`, *optional*, defaults to 4096):
  99. Sliding window attention (SWA) window size. If not specified, will default to `4096`.
  100. max_window_layers (`int`, *optional*, defaults to 28):
  101. The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  102. attention_dropout (`float`, *optional*, defaults to 0.0):
  103. The dropout ratio for the attention probabilities.
  104. """
  105. model_type = "qwen2"
  106. keys_to_ignore_at_inference = ["past_key_values"]
  107. def __init__(
  108. self,
  109. vocab_size=151936,
  110. hidden_size=4096,
  111. intermediate_size=22016,
  112. num_hidden_layers=32,
  113. num_attention_heads=32,
  114. num_key_value_heads=32,
  115. hidden_act="silu",
  116. max_position_embeddings=32768,
  117. seq_length=32768,
  118. initializer_range=0.02,
  119. rms_norm_eps=1e-6,
  120. use_cache=True,
  121. tie_word_embeddings=False,
  122. rope_theta=10000.0,
  123. pad_token_id=0,
  124. bos_token_id=151643,
  125. eos_token_id=151643,
  126. use_sliding_window=False,
  127. sliding_window=4096,
  128. max_window_layers=28,
  129. attention_dropout=0.0,
  130. rope_scaling_factor=1.0,
  131. rope_scaling_type=None,
  132. dpo_config=None,
  133. **kwargs,
  134. ):
  135. self.vocab_size = vocab_size
  136. self.max_position_embeddings = max_position_embeddings
  137. self.seq_length = seq_length
  138. self.hidden_size = hidden_size
  139. self.intermediate_size = intermediate_size
  140. self.num_hidden_layers = num_hidden_layers
  141. self.num_attention_heads = num_attention_heads
  142. self.use_sliding_window = use_sliding_window
  143. self.sliding_window = sliding_window
  144. self.max_window_layers = max_window_layers
  145. # for backward compatibility
  146. if num_key_value_heads is None:
  147. num_key_value_heads = num_attention_heads
  148. self.num_key_value_heads = num_key_value_heads
  149. self.hidden_act = hidden_act
  150. self.initializer_range = initializer_range
  151. self.rms_norm_eps = rms_norm_eps
  152. self.use_cache = use_cache
  153. self.rope_theta = rope_theta
  154. self.attention_dropout = attention_dropout
  155. self.use_cache = use_cache
  156. self.rope_scaling_factor = rope_scaling_factor
  157. self.rope_scaling_type = rope_scaling_type
  158. self.pad_token_id = pad_token_id
  159. self.bos_token_id = bos_token_id
  160. self.eos_token_id = eos_token_id
  161. self.dpo_config = dpo_config
  162. super().__init__(
  163. pad_token_id=pad_token_id,
  164. bos_token_id=bos_token_id,
  165. eos_token_id=eos_token_id,
  166. tie_word_embeddings=tie_word_embeddings,
  167. **kwargs,
  168. )
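# Illustrative usage (not part of the original module; the values below are
# arbitrary assumptions for demonstration). The defaults above correspond to
# Qwen2-7B-beta; a reduced config for quick experiments might look like:
#
#     tiny_cfg = Qwen2Config(
#         vocab_size=32000,
#         hidden_size=512,
#         intermediate_size=1408,
#         num_hidden_layers=4,
#         num_attention_heads=8,
#         num_key_value_heads=2,  # GQA: 8 query heads share 2 KV heads
#     )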
  169. def get_triangle_upper_mask(x, mask=None):
  170. if mask is not None:
  171. return mask
  172. # [bsz, n_head, q_len, kv_seq_len]
  173. shape = x.shape
  174. # [bsz, 1, q_len, kv_seq_len]
  175. shape[1] = 1
  176. mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
  177. mask = paddle.triu(mask, diagonal=1)
  178. mask.stop_gradient = True
  179. return mask
  180. def parallel_matmul(
  181. x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
  182. ):
  183. is_fleet_init = True
  184. tensor_parallel_degree = 1
  185. try:
  186. hcg = fleet.get_hybrid_communicate_group()
  187. model_parallel_group = hcg.get_model_parallel_group()
  188. tensor_parallel_degree = hcg.get_model_parallel_world_size()
  189. except Exception:
  190. is_fleet_init = False
  191. if paddle.in_dynamic_mode():
  192. y_is_distributed = y.is_distributed
  193. else:
  194. y_is_distributed = tensor_parallel_degree > 1
  195. if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
  196. # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
  197. input_parallel = paddle.distributed.collective._c_identity(
  198. x, group=model_parallel_group
  199. )
  200. logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
  201. if tensor_parallel_output:
  202. return logits
  203. return paddle.distributed.collective._c_concat(
  204. logits, group=model_parallel_group
  205. )
  206. else:
  207. logits = paddle.matmul(x, y, transpose_y=transpose_y)
  208. return logits
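# Illustrative sketch (assumption, not part of the original module): how
# parallel_matmul is typically used by the LM head further below. With a
# column-sharded weight of shape [hidden_size, vocab_size // mp_degree], each
# rank computes a vocab slice of the logits; tensor_parallel_output=False
# gathers the slices along the last axis via _c_concat.
#
#     local_logits = parallel_matmul(hidden, weight, transpose_y=False,
#                                    tensor_parallel_output=True)
#     full_logits = parallel_matmul(hidden, weight, transpose_y=False,
#                                   tensor_parallel_output=False)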
  209. def scaled_dot_product_attention(
  210. query_states,
  211. config,
  212. key_states,
  213. value_states,
  214. attention_mask,
  215. output_attentions,
  216. attn_mask_startend_row_indices=None,
  217. training=True,
  218. sequence_parallel=False,
  219. skip_recompute=False,
  220. ):
  221. bsz, q_len, num_heads, head_dim = query_states.shape
  222. _, kv_seq_len, _, _ = value_states.shape
  223. if config.use_flash_attention and flash_attention:
  224. # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
  225. # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
  226. return fusion_ops.fusion_flash_attention(
  227. query_states,
  228. config,
  229. key_states,
  230. value_states,
  231. attention_mask,
  232. output_attentions,
  233. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  234. sequence_parallel=sequence_parallel,
  235. skip_recompute=skip_recompute,
  236. )
  237. else:
  238. # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
  239. query_states = paddle.transpose(query_states, [0, 2, 1, 3])
  240. # merge with the next transpose
  241. key_states = paddle.transpose(key_states, [0, 2, 1, 3])
  242. value_states = paddle.transpose(value_states, [0, 2, 1, 3])
  243. # Apply a pre-division factor to avoid NaN overflow under float16.
  244. if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16:
  245. pre_divided_factor = 32
  246. else:
  247. pre_divided_factor = 1
  248. attn_weights = paddle.matmul(
  249. query_states / (math.sqrt(head_dim) * pre_divided_factor),
  250. key_states.transpose([0, 1, 3, 2]),
  251. )
  252. if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
  253. raise ValueError(
  254. f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
  255. f" {attn_weights.shape}"
  256. )
  257. if attention_mask is None:
  258. attention_mask = get_triangle_upper_mask(attn_weights)
  259. attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
  260. if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
  261. raise ValueError(
  262. f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
  263. )
  264. attn_weights = attn_weights + attention_mask
  265. if not paddle.in_dynamic_mode():
  266. attn_weights = F.softmax(
  267. attn_weights * pre_divided_factor, axis=-1, dtype="float32"
  268. ).astype(query_states.dtype)
  269. else:
  270. with paddle.amp.auto_cast(False):
  271. attn_weights = F.softmax(
  272. attn_weights.astype("float32") * pre_divided_factor,
  273. axis=-1,
  274. dtype="float32",
  275. ).astype(query_states.dtype)
  276. attn_weights = F.dropout(
  277. attn_weights, p=config.attention_dropout, training=training
  278. )
  279. attn_output = paddle.matmul(attn_weights, value_states)
  280. attn_output = attn_output.transpose([0, 2, 1, 3])
  281. if sequence_parallel:
  282. attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
  283. else:
  284. attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
  285. return (attn_output, attn_weights) if output_attentions else attn_output
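# Illustrative shape walk-through of the eager (non-flash) branch above,
# assuming bsz=2, q_len=kv_seq_len=16, num_heads=8, head_dim=64 (not part of
# the original module):
#
#     q, k, v: [2, 16, 8, 64] -> transposed to [2, 8, 16, 64]
#     attn_weights = q @ k^T / sqrt(64)          # [2, 8, 16, 16]
#     attn_weights += attention_mask             # masked positions -> large negative
#     attn_output = softmax(attn_weights) @ v    # [2, 8, 16, 64]
#     attn_output -> reshape to [2, 16, 8 * 64]  # ready for o_proj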
  286. def is_casual_mask(attention_mask):
  287. """
  288. Returns True if the upper triangle of attention_mask equals attention_mask, i.e. the mask is causal.
  289. """
  290. return (paddle.triu(attention_mask) == attention_mask).all().item()
  291. def _make_causal_mask(input_ids_shape, past_key_values_length):
  292. """
  293. Make causal mask used for self-attention
  294. """
  295. batch_size, target_length = input_ids_shape # target_length: seq_len
  296. mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
  297. if past_key_values_length > 0:
  298. # [tgt_len, tgt_len + past_len]
  299. mask = paddle.concat(
  300. [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
  301. axis=-1,
  302. )
  303. # [bs, 1, tgt_len, tgt_len + past_len]
  304. return mask[None, None, :, :].expand(
  305. [batch_size, 1, target_length, target_length + past_key_values_length]
  306. )
  307. def _expand_2d_mask(mask, dtype, tgt_length):
  308. """
  309. Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
  310. """
  311. batch_size, src_length = mask.shape[0], mask.shape[-1]
  312. tgt_length = tgt_length if tgt_length is not None else src_length
  313. mask = mask[:, None, None, :].astype("bool")
  314. mask.stop_gradient = True
  315. expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
  316. return expanded_mask
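# Illustrative only (not part of the original module): how the two helpers
# above are combined in _prepare_decoder_attention_mask. A 2-D padding mask of
# shape [bs, src_len] is broadcast to [bs, 1, tgt_len, src_len] and AND-ed with
# the lower-triangular causal mask, e.g. for one sequence of length 4 with one
# padded position:
#
#     pad_mask = paddle.to_tensor([[1, 1, 1, 0]])                   # [1, 4]
#     causal = _make_causal_mask((1, 4), past_key_values_length=0)  # [1, 1, 4, 4]
#     combined = _expand_2d_mask(pad_mask, "float32", tgt_length=4) & causal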
  317. class Qwen2RMSNorm(nn.Layer):
  318. def __init__(self, config: Qwen2Config):
  319. """
  320. Qwen2RMSNorm is equivalent to T5LayerNorm
  321. """
  322. super().__init__()
  323. self.hidden_size = config.hidden_size
  324. self.weight = paddle.create_parameter(
  325. shape=[self.hidden_size],
  326. dtype=paddle.get_default_dtype(),
  327. default_initializer=nn.initializer.Constant(1.0),
  328. )
  329. self.variance_epsilon = config.rms_norm_eps
  330. self.config = config
  331. if config.sequence_parallel:
  332. mark_as_sequence_parallel_parameter(self.weight)
  333. def forward(self, hidden_states):
  334. if self.config.use_fused_rms_norm:
  335. return fusion_ops.fusion_rms_norm(
  336. hidden_states, self.weight, self.variance_epsilon, False
  337. )
  338. if paddle.in_dynamic_mode():
  339. with paddle.amp.auto_cast(False):
  340. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  341. hidden_states = (
  342. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  343. )
  344. else:
  345. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  346. hidden_states = (
  347. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  348. )
  349. if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
  350. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  351. return hidden_states * self.weight
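# For reference (not part of the original module): both branches above compute,
# with the variance accumulated in float32 for numerical stability,
#
#     y = x / sqrt(mean(x ** 2, axis=-1, keepdim=True) + eps) * weight
#
# i.e. RMSNorm rescales by the root mean square of the features and applies a
# learned gain, without the mean subtraction or bias of LayerNorm.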
  352. class Qwen2RotaryEmbedding(nn.Layer):
  353. def __init__(self, dim, max_position_embeddings=2048, base=10000):
  354. super().__init__()
  355. self.dim = dim
  356. self.max_position_embeddings = max_position_embeddings
  357. self.base = base
  358. # [dim / 2]
  359. self.inv_freq = 1.0 / (
  360. self.base
  361. ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
  362. )
  363. self._set_cos_sin_cache(seq_len=max_position_embeddings)
  364. def _set_cos_sin_cache(self, seq_len):
  365. self.max_seq_len_cached = seq_len
  366. # [seq_len]
  367. t = paddle.arange(seq_len, dtype="float32")
  368. # [seq_len, dim/2]
  369. freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
  370. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  371. # [seq_len, dim]
  372. emb = paddle.concat([freqs, freqs], axis=-1)
  373. # [1, seqlen, 1, dim]
  374. self.cos_cached = emb.cos()[None, :, None, :]
  375. self.sin_cached = emb.sin()[None, :, None, :]
  376. def forward(self, x, seq_len=None):
  377. # x: [bs, num_attention_heads, seq_len, head_size]
  378. if seq_len > self.max_seq_len_cached:
  379. self._set_cos_sin_cache(seq_len)
  380. cos = self.cos_cached[:, :seq_len, :, :]
  381. sin = self.sin_cached[:, :seq_len, :, :]
  382. return (
  383. cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
  384. sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
  385. )
  386. def rotate_half(x):
  387. """Rotates half the hidden dims of the input."""
  388. x1 = x[..., : x.shape[-1] // 2]
  389. x2 = x[..., x.shape[-1] // 2 :]
  390. return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
  391. def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  392. if position_ids is None:
  393. # Note: Only for Qwen2MoEForCausalLMPipe model pretraining
  394. cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
  395. sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
  396. else:
  397. cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
  398. sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
  399. cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
  400. sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
  401. q_embed = (q * cos) + (rotate_half(q) * sin)
  402. k_embed = (k * cos) + (rotate_half(k) * sin)
  403. return q_embed, k_embed
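# Illustrative usage (assumption, not part of the original module), following
# the [bs, seq_len, num_heads, head_dim] layout used throughout this file:
#
#     rope = Qwen2RotaryEmbedding(dim=64, max_position_embeddings=2048)
#     cos, sin = rope(value_states, seq_len=16)      # each [1, 16, 1, 64]
#     position_ids = paddle.arange(16).unsqueeze(0)  # [1, 16]
#     q, k = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)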
  404. class Qwen2MLP(nn.Layer):
  405. def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None):
  406. super().__init__()
  407. if skip_recompute_ops is None:
  408. skip_recompute_ops = {}
  409. self.skip_recompute_ops = skip_recompute_ops
  410. self.hidden_size = config.hidden_size
  411. self.intermediate_size = config.intermediate_size
  412. self.fuse_attention_ffn = config.fuse_attention_ffn
  413. self.tensor_parallel_degree = config.tensor_parallel_degree
  414. if config.sequence_parallel:
  415. ColumnParallelLinear = ColumnSequenceParallelLinear
  416. RowParallelLinear = RowSequenceParallelLinear
  417. if config.tensor_parallel_degree > 1:
  418. if self.fuse_attention_ffn:
  419. self.gate_up_fused_proj = ColumnParallelLinear(
  420. self.hidden_size,
  421. self.intermediate_size * 2,
  422. gather_output=False,
  423. has_bias=False,
  424. )
  425. else:
  426. self.gate_proj = ColumnParallelLinear(
  427. self.hidden_size,
  428. self.intermediate_size,
  429. gather_output=False,
  430. has_bias=False,
  431. )
  432. self.up_proj = ColumnParallelLinear(
  433. self.hidden_size,
  434. self.intermediate_size,
  435. gather_output=False,
  436. has_bias=False,
  437. )
  438. self.down_proj = RowParallelLinear(
  439. self.intermediate_size,
  440. self.hidden_size,
  441. input_is_parallel=True,
  442. has_bias=False,
  443. )
  444. else:
  445. if self.fuse_attention_ffn:
  446. self.gate_up_fused_proj = Linear(
  447. self.hidden_size, self.intermediate_size * 2, bias_attr=False
  448. )
  449. else:
  450. self.gate_proj = Linear(
  451. self.hidden_size, self.intermediate_size, bias_attr=False
  452. ) # w1
  453. self.up_proj = Linear(
  454. self.hidden_size, self.intermediate_size, bias_attr=False
  455. ) # w3
  456. self.down_proj = Linear(
  457. self.intermediate_size, self.hidden_size, bias_attr=False
  458. ) # w2
  459. if config.hidden_act == "silu":
  460. self.act_fn = fusion_ops.swiglu
  461. self.fuse_swiglu = True
  462. else:
  463. self.act_fn = ACT2FN[config.hidden_act]
  464. self.fuse_swiglu = False
  465. def forward(self, x):
  466. if self.fuse_attention_ffn:
  467. x = self.gate_up_fused_proj(x)
  468. if self.fuse_swiglu:
  469. y = None
  470. else:
  471. x, y = x.chunk(2, axis=-1)
  472. else:
  473. x, y = self.gate_proj(x), self.up_proj(x)
  474. if self.fuse_swiglu:
  475. x = self.act_fn(x, y)
  476. else:
  477. x = self.act_fn(x) * y
  478. return self.down_proj(x)
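# For reference (not part of the original module): the unfused path above is
# the standard SwiGLU MLP,
#
#     down_proj(silu(gate_proj(x)) * up_proj(x))
#
# while the fused variants obtain the gate and up activations from a single
# gate_up_fused_proj matmul and, for "silu", apply them via the fused swiglu op.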
  479. def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
  480. """
  481. This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=2). The hidden states go from
  482. (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim).
  483. """
  484. batch, slen, num_key_value_heads, head_dim = hidden_states.shape
  485. if n_rep == 1:
  486. return hidden_states
  487. hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
  488. return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim])
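# Illustrative only (not part of the original module): expanding grouped KV
# heads so attention can be computed as if every query head had its own KV head.
#
#     kv = paddle.randn([2, 16, 4, 64])   # [batch, seq_len, kv_heads, head_dim]
#     expanded = repeat_kv(kv, n_rep=8)   # -> [2, 16, 32, 64]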
  489. class Qwen2Attention(nn.Layer):
  490. """
  491. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  492. and "Generating Long Sequences with Sparse Transformers".
  493. """
  494. def __init__(
  495. self,
  496. config: Qwen2Config,
  497. layerwise_recompute: bool = True,
  498. skip_recompute_ops=None,
  499. ):
  500. super().__init__()
  501. if skip_recompute_ops is None:
  502. skip_recompute_ops = {}
  503. self.config = config
  504. self.skip_recompute_ops = skip_recompute_ops
  505. self.hidden_size = config.hidden_size
  506. self.num_heads = config.num_attention_heads
  507. self.head_dim = self.hidden_size // config.num_attention_heads
  508. self.num_key_value_heads = config.num_key_value_heads
  509. assert config.num_attention_heads % config.num_key_value_heads == 0
  510. self.num_key_value_groups = (
  511. config.num_attention_heads // config.num_key_value_heads
  512. )
  513. self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
  514. self.max_position_embeddings = config.max_position_embeddings
  515. self.rope_theta = config.rope_theta
  516. self.is_causal = True
  517. self.attention_dropout = config.attention_dropout
  518. self.seq_length = config.seq_length
  519. self.sequence_parallel = config.sequence_parallel
  520. self.fuse_attention_qkv = config.fuse_attention_qkv
  521. # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
  522. # enable_recompute defaults to False and is controlled by the Trainer
  523. self.enable_recompute = False
  524. self.layerwise_recompute = layerwise_recompute
  525. self.recompute_granularity = config.recompute_granularity
  526. if config.tensor_parallel_degree > 1:
  527. assert (
  528. self.num_heads % config.tensor_parallel_degree == 0
  529. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  530. self.num_heads = self.num_heads // config.tensor_parallel_degree
  531. assert (
  532. self.num_key_value_heads % config.tensor_parallel_degree == 0
  533. ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  534. self.num_key_value_heads = (
  535. self.num_key_value_heads // config.tensor_parallel_degree
  536. )
  537. self.use_fused_rope = config.use_fused_rope
  538. if self.use_fused_rope:
  539. if (
  540. get_device_type() not in ["gpu", "xpu"]
  541. or fused_rotary_position_embedding is None
  542. ):
  543. logging.warning(
  544. "Enable fuse rope in the config, but fuse rope is not available. "
  545. "Will disable fuse rope. Try using latest gpu version of Paddle."
  546. )
  547. self.use_fused_rope = False
  548. if config.sequence_parallel:
  549. ColumnParallelLinear = ColumnSequenceParallelLinear
  550. RowParallelLinear = RowSequenceParallelLinear
  551. if config.tensor_parallel_degree > 1:
  552. if self.fuse_attention_qkv:
  553. self.qkv_proj = ColumnParallelLinear(
  554. self.hidden_size,
  555. self.hidden_size
  556. + 2 * self.config.num_key_value_heads * self.head_dim,
  557. has_bias=True,
  558. gather_output=False,
  559. )
  560. else:
  561. self.q_proj = ColumnParallelLinear(
  562. self.hidden_size,
  563. self.hidden_size,
  564. has_bias=True,
  565. gather_output=False,
  566. )
  567. self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  568. self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  569. self.o_proj = RowParallelLinear(
  570. self.hidden_size,
  571. self.hidden_size,
  572. has_bias=False,
  573. input_is_parallel=True,
  574. )
  575. else:
  576. if self.fuse_attention_qkv:
  577. self.qkv_proj = Linear(
  578. self.hidden_size,
  579. self.hidden_size
  580. + 2 * self.config.num_key_value_heads * self.head_dim,
  581. )
  582. else:
  583. self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
  584. self.k_proj = Linear(
  585. self.hidden_size,
  586. self.config.num_key_value_heads * self.head_dim,
  587. bias_attr=True,
  588. )
  589. self.v_proj = Linear(
  590. self.hidden_size,
  591. self.config.num_key_value_heads * self.head_dim,
  592. bias_attr=True,
  593. )
  594. self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
  595. self.rotary_emb = Qwen2RotaryEmbedding(
  596. self.head_dim,
  597. max_position_embeddings=self.max_position_embeddings,
  598. base=self.rope_theta,
  599. )
  600. self.attn_func = scaled_dot_product_attention
  601. def forward(
  602. self,
  603. hidden_states,
  604. position_ids: Optional[Tuple[paddle.Tensor]] = None,
  605. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  606. attention_mask: Optional[paddle.Tensor] = None,
  607. output_attentions: bool = False,
  608. use_cache: bool = False,
  609. attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
  610. **kwargs,
  611. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  612. """Input shape: Batch x Time x Channel"""
  613. # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
  614. if self.fuse_attention_qkv:
  615. mix_layer = self.qkv_proj(hidden_states)
  616. if self.sequence_parallel:
  617. target_shape = [
  618. -1,
  619. self.seq_length,
  620. self.num_key_value_heads,
  621. (self.num_key_value_groups + 2) * self.head_dim,
  622. ]
  623. else:
  624. target_shape = [
  625. 0,
  626. 0,
  627. self.num_key_value_heads,
  628. (self.num_key_value_groups + 2) * self.head_dim,
  629. ]
  630. mix_layer = paddle.reshape_(mix_layer, target_shape)
  631. query_states, key_states, value_states = paddle.split(
  632. mix_layer,
  633. num_or_sections=[
  634. self.num_key_value_groups * self.head_dim,
  635. self.head_dim,
  636. self.head_dim,
  637. ],
  638. axis=-1,
  639. )
  640. if self.gqa_or_mqa:
  641. query_states = paddle.reshape_(
  642. query_states, [0, 0, self.num_heads, self.head_dim]
  643. )
  644. else:
  645. query_states = self.q_proj(hidden_states)
  646. key_states = self.k_proj(hidden_states)
  647. value_states = self.v_proj(hidden_states)
  648. if self.sequence_parallel:
  649. target_query_shape = [
  650. -1,
  651. self.seq_length,
  652. self.num_heads,
  653. self.head_dim,
  654. ]
  655. target_key_value_shape = [
  656. -1,
  657. self.seq_length,
  658. self.num_key_value_heads,
  659. self.head_dim,
  660. ]
  661. else:
  662. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  663. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  664. query_states = query_states.reshape(shape=target_query_shape)
  665. key_states = key_states.reshape(shape=target_key_value_shape)
  666. value_states = value_states.reshape(shape=target_key_value_shape)
  667. kv_seq_len = key_states.shape[-3]
  668. if past_key_value is not None:
  669. kv_seq_len += past_key_value[0].shape[-3]
  670. if self.use_fused_rope:
  671. assert past_key_value is None, "fuse rotary not support cache kv for now"
  672. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  673. query_states, key_states, _ = fused_rotary_position_embedding(
  674. query_states,
  675. key_states,
  676. v=None,
  677. sin=sin,
  678. cos=cos,
  679. position_ids=position_ids,
  680. use_neox_rotary_style=False,
  681. )
  682. else:
  683. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  684. query_states, key_states = apply_rotary_pos_emb(
  685. query_states, key_states, cos, sin, position_ids
  686. )
  687. # [bs, seq_len, num_head, head_dim]
  688. if past_key_value is not None:
  689. key_states = paddle.concat([past_key_value[0], key_states], axis=1)
  690. value_states = paddle.concat([past_key_value[1], value_states], axis=1)
  691. past_key_value = (key_states, value_states) if use_cache else None
  692. # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
  693. # repeat k/v heads if n_kv_heads < n_heads
  694. paddle_version = float(paddle.__version__[:3])
  695. if not self.config.use_flash_attention or (
  696. (paddle_version != 0.0) and (paddle_version <= 2.6)
  697. ):
  698. key_states = repeat_kv(key_states, self.num_key_value_groups)
  699. value_states = repeat_kv(value_states, self.num_key_value_groups)
  700. outputs = self.attn_func(
  701. query_states,
  702. self.config,
  703. key_states,
  704. value_states,
  705. attention_mask,
  706. output_attentions,
  707. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  708. training=self.training,
  709. sequence_parallel=self.sequence_parallel,
  710. )
  711. if output_attentions:
  712. attn_output, attn_weights = outputs
  713. else:
  714. attn_output = outputs
  715. # if sequence_parallel is true, the output shape is [q_len / n, bs, num_head * head_dim];
  716. # otherwise it is [bs, q_len, num_head * head_dim], where n is the mp parallelism degree.
  717. attn_output = self.o_proj(attn_output)
  718. if not output_attentions:
  719. attn_weights = None
  720. outputs = (attn_output,)
  721. if output_attentions:
  722. outputs += (attn_weights,)
  723. if use_cache:
  724. outputs += (past_key_value,)
  725. if type(outputs) is tuple and len(outputs) == 1:
  726. outputs = outputs[0]
  727. return outputs
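# For reference (not part of the original module): because tensors are kept in
# the [bs, seq_len, num_kv_heads, head_dim] layout, the KV cache above grows
# along axis=1 at each decoding step:
#
#     # step t: new key_states is [bs, 1, kv_heads, head_dim]
#     key_states = paddle.concat([past_key_value[0], key_states], axis=1)
#     # -> [bs, cached_len + 1, kv_heads, head_dim]; the pair
#     # (key_states, value_states) is returned as past_key_value when use_cache=True.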
  728. class Qwen2DecoderLayer(nn.Layer):
  729. def __init__(
  730. self,
  731. config: Qwen2Config,
  732. layerwise_recompute: bool = False,
  733. skip_recompute_ops=None,
  734. ):
  735. super().__init__()
  736. if skip_recompute_ops is None:
  737. skip_recompute_ops = {}
  738. self.config = config
  739. self.skip_recompute_ops = skip_recompute_ops
  740. self.hidden_size = config.hidden_size
  741. self.self_attn = Qwen2Attention(
  742. config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops
  743. )
  744. self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops)
  745. self.input_layernorm = Qwen2RMSNorm(config)
  746. self.post_attention_layernorm = Qwen2RMSNorm(config)
  747. # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
  748. # enable_recompute defaults to False and is controlled by the Trainer
  749. self.enable_recompute = False
  750. self.layerwise_recompute = layerwise_recompute
  751. self.recompute_granularity = config.recompute_granularity
  752. def forward(
  753. self,
  754. hidden_states: paddle.Tensor,
  755. position_ids: Optional[paddle.Tensor] = None,
  756. attention_mask: Optional[paddle.Tensor] = None,
  757. output_attentions: Optional[bool] = False,
  758. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  759. use_cache: Optional[bool] = False,
  760. attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
  761. **kwargs,
  762. ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
  763. """
  764. Args:
  765. hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  766. attention_mask (`paddle.Tensor`, *optional*): attention mask of size
  767. `(batch, sequence_length)` where padding elements are indicated by 0.
  768. output_attentions (`bool`, *optional*):
  769. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  770. returned tensors for more detail.
  771. use_cache (`bool`, *optional*):
  772. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  773. (see `past_key_values`).
  774. past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
  775. """
  776. # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
  777. residual = hidden_states
  778. hidden_states = self.input_layernorm(hidden_states)
  779. # Self Attention
  780. outputs = self.self_attn(
  781. hidden_states,
  782. position_ids,
  783. past_key_value,
  784. attention_mask,
  785. output_attentions,
  786. use_cache,
  787. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  788. )
  789. if type(outputs) is tuple:
  790. hidden_states = outputs[0]
  791. else:
  792. hidden_states = outputs
  793. if output_attentions:
  794. self_attn_weights = outputs[1]
  795. if use_cache:
  796. present_key_value = outputs[2 if output_attentions else 1]
  797. hidden_states = residual + hidden_states
  798. # Fully Connected
  799. residual = hidden_states
  800. hidden_states = self.post_attention_layernorm(hidden_states)
  801. hidden_states = self.mlp(hidden_states)
  802. hidden_states = residual + hidden_states
  803. outputs = (hidden_states,)
  804. if output_attentions:
  805. outputs += (self_attn_weights,)
  806. if use_cache:
  807. outputs += (present_key_value,)
  808. if type(outputs) is tuple and len(outputs) == 1:
  809. outputs = outputs[0]
  810. return outputs
  811. class Qwen2PretrainedModel(PretrainedModel):
  812. config_class = Qwen2Config
  813. base_model_prefix = "qwen2"
  814. _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
  815. @classmethod
  816. def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
  817. from paddlenlp.transformers.conversion_utils import split_or_merge_func
  818. fn = split_or_merge_func(
  819. is_split=is_split,
  820. tensor_parallel_degree=config.tensor_parallel_degree,
  821. tensor_parallel_rank=config.tensor_parallel_rank,
  822. num_attention_heads=config.num_attention_heads,
  823. )
  824. def get_tensor_parallel_split_mappings(num_layers):
  825. final_actions = {}
  826. base_actions = {
  827. # Row Linear
  828. "embed_tokens.weight": partial(fn, is_column=False),
  829. "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
  830. "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
  831. }
  832. if config.tie_word_embeddings:
  833. base_actions["lm_head.weight"] = partial(fn, is_column=False)
  834. else:
  835. base_actions["lm_head.weight"] = partial(fn, is_column=True)
  836. if config.vocab_size % config.tensor_parallel_degree != 0:
  837. base_actions.pop("lm_head.weight")
  838. base_actions.pop("embed_tokens.weight")
  839. # Column Linear
  840. if config.fuse_attention_qkv:
  841. base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(
  842. fn, is_column=True
  843. )
  844. base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(
  845. fn, is_column=True
  846. )
  847. else:
  848. base_actions["layers.0.self_attn.q_proj.weight"] = partial(
  849. fn, is_column=True
  850. )
  851. base_actions["layers.0.self_attn.q_proj.bias"] = partial(
  852. fn, is_column=True
  853. )
  854. # if we have enough num_key_value_heads to split, then split it.
  855. if config.num_key_value_heads % config.tensor_parallel_degree == 0:
  856. base_actions["layers.0.self_attn.k_proj.weight"] = partial(
  857. fn, is_column=True
  858. )
  859. base_actions["layers.0.self_attn.v_proj.weight"] = partial(
  860. fn, is_column=True
  861. )
  862. base_actions["layers.0.self_attn.k_proj.bias"] = partial(
  863. fn, is_column=True
  864. )
  865. base_actions["layers.0.self_attn.v_proj.bias"] = partial(
  866. fn, is_column=True
  867. )
  868. if config.fuse_attention_ffn:
  869. base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
  870. fn, is_column=True, is_naive_2fuse=True
  871. )
  872. else:
  873. base_actions["layers.0.mlp.gate_proj.weight"] = partial(
  874. fn, is_column=True
  875. )
  876. base_actions["layers.0.mlp.up_proj.weight"] = partial(
  877. fn, is_column=True
  878. )
  879. for key, action in base_actions.items():
  880. if "layers.0." in key:
  881. for i in range(num_layers):
  882. final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
  883. final_actions[key] = action
  884. return final_actions
  885. mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
  886. return mappings
  887. @classmethod
  888. def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
  889. # return parameter fuse utils
  890. from paddlenlp.transformers.conversion_utils import split_or_fuse_func
  891. fn = split_or_fuse_func(is_fuse=is_fuse)
  892. # last key is fused key, other keys are to be fused.
  893. fuse_qkv_keys = [
  894. (
  895. "layers.0.self_attn.q_proj.weight",
  896. "layers.0.self_attn.k_proj.weight",
  897. "layers.0.self_attn.v_proj.weight",
  898. "layers.0.self_attn.qkv_proj.weight",
  899. ),
  900. (
  901. "layers.0.self_attn.q_proj.bias",
  902. "layers.0.self_attn.k_proj.bias",
  903. "layers.0.self_attn.v_proj.bias",
  904. "layers.0.self_attn.qkv_proj.bias",
  905. ),
  906. ]
  907. fuse_gate_up_keys = (
  908. "layers.0.mlp.gate_proj.weight",
  909. "layers.0.mlp.up_proj.weight",
  910. "layers.0.mlp.gate_up_fused_proj.weight",
  911. )
  912. num_heads = config.num_attention_heads
  913. num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
  914. fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
  915. fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
  916. final_actions = {}
  917. if is_fuse:
  918. if fuse_attention_qkv:
  919. for i in range(config.num_hidden_layers):
  920. for fuse_keys in fuse_qkv_keys:
  921. keys = tuple(
  922. [
  923. key.replace("layers.0.", f"layers.{i}.")
  924. for key in fuse_keys
  925. ]
  926. )
  927. final_actions[keys] = partial(
  928. fn,
  929. is_qkv=True,
  930. num_heads=num_heads,
  931. num_key_value_heads=num_key_value_heads,
  932. )
  933. if fuse_attention_ffn:
  934. for i in range(config.num_hidden_layers):
  935. keys = tuple(
  936. [
  937. key.replace("layers.0.", f"layers.{i}.")
  938. for key in fuse_gate_up_keys
  939. ]
  940. )
  941. final_actions[keys] = fn
  942. else:
  943. if not fuse_attention_qkv:
  944. for i in range(config.num_hidden_layers):
  945. for fuse_keys in fuse_qkv_keys:
  946. keys = tuple(
  947. [
  948. key.replace("layers.0.", f"layers.{i}.")
  949. for key in fuse_keys
  950. ]
  951. )
  952. final_actions[keys] = partial(
  953. fn,
  954. split_nums=3,
  955. is_qkv=True,
  956. num_heads=num_heads,
  957. num_key_value_heads=num_key_value_heads,
  958. )
  959. if not fuse_attention_ffn:
  960. for i in range(config.num_hidden_layers):
  961. keys = tuple(
  962. [
  963. key.replace("layers.0.", f"layers.{i}.")
  964. for key in fuse_gate_up_keys
  965. ]
  966. )
  967. final_actions[keys] = partial(fn, split_nums=2)
  968. return final_actions
  969. class Qwen2Model(Qwen2PretrainedModel):
  970. """
  971. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
  972. Args:
  973. config: Qwen2Config
  974. """
  975. def __init__(self, config: Qwen2Config):
  976. super().__init__(config)
  977. self.padding_idx = config.pad_token_id
  978. self.vocab_size = config.vocab_size
  979. self.hidden_size = config.hidden_size
  980. self.sequence_parallel = config.sequence_parallel
  981. self.recompute_granularity = config.recompute_granularity
  982. self.no_recompute_layers = (
  983. config.no_recompute_layers if config.no_recompute_layers is not None else []
  984. )
  985. # Recompute defaults to False and is controlled by Trainer
  986. self.enable_recompute = False
  987. if (
  988. config.tensor_parallel_degree > 1
  989. and config.vocab_size % config.tensor_parallel_degree == 0
  990. ):
  991. self.embed_tokens = mpu.VocabParallelEmbedding(
  992. self.vocab_size,
  993. self.hidden_size,
  994. weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
  995. )
  996. else:
  997. self.embed_tokens = nn.Embedding(
  998. self.vocab_size,
  999. self.hidden_size,
  1000. )
  1001. self.layers = nn.LayerList(
  1002. [
  1003. Qwen2DecoderLayer(
  1004. config=config,
  1005. layerwise_recompute=layer_idx not in self.no_recompute_layers,
  1006. )
  1007. for layer_idx in range(config.num_hidden_layers)
  1008. ]
  1009. )
  1010. self.norm = Qwen2RMSNorm(config)
  1011. def get_input_embeddings(self):
  1012. return self.embed_tokens
  1013. def set_input_embeddings(self, value):
  1014. self.embed_tokens = value
  1015. @staticmethod
  1016. def _prepare_decoder_attention_mask(
  1017. attention_mask, input_shape, past_key_values_length, dtype
  1018. ):
  1019. if attention_mask is not None:
  1020. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1021. if len(attention_mask.shape) == 2:
  1022. expanded_attn_mask = _expand_2d_mask(
  1023. attention_mask, dtype, tgt_length=input_shape[-1]
  1024. )
  1025. # For the decoding phase in generation (seq_length == 1), we don't need to add a causal mask
  1026. if input_shape[-1] > 1:
  1027. combined_attention_mask = _make_causal_mask(
  1028. input_shape,
  1029. past_key_values_length=past_key_values_length,
  1030. )
  1031. expanded_attn_mask = expanded_attn_mask & combined_attention_mask
  1032. # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
  1033. elif len(attention_mask.shape) == 3:
  1034. expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
  1035. # if attention_mask is already 4-D, do nothing
  1036. else:
  1037. expanded_attn_mask = attention_mask
  1038. else:
  1039. expanded_attn_mask = _make_causal_mask(
  1040. input_shape,
  1041. past_key_values_length=past_key_values_length,
  1042. )
  1043. # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
  1044. if get_device_type() == "xpu":
  1045. x = paddle.to_tensor(0.0, dtype="float32")
  1046. y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
  1047. expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
  1048. else:
  1049. expanded_attn_mask = paddle.where(
  1050. expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min
  1051. ).astype(dtype)
  1052. return expanded_attn_mask
  1053. def forward(
  1054. self,
  1055. input_ids: paddle.Tensor = None,
  1056. position_ids: Optional[paddle.Tensor] = None,
  1057. attention_mask: Optional[paddle.Tensor] = None,
  1058. inputs_embeds: Optional[paddle.Tensor] = None,
  1059. use_cache: Optional[bool] = None,
  1060. past_key_values: Optional[List[paddle.Tensor]] = None,
  1061. output_attentions: Optional[bool] = None,
  1062. output_hidden_states: Optional[bool] = None,
  1063. return_dict: Optional[bool] = None,
  1064. attn_mask_startend_row_indices=None,
  1065. ) -> Union[Tuple, BaseModelOutputWithPast]:
  1066. output_attentions = (
  1067. output_attentions
  1068. if output_attentions is not None
  1069. else self.config.output_attentions
  1070. )
  1071. output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
  1072. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1073. return_dict = (
  1074. return_dict if return_dict is not None else self.config.use_return_dict
  1075. )
  1076. # retrieve input_ids and inputs_embeds
  1077. if input_ids is not None and inputs_embeds is not None:
  1078. raise ValueError(
  1079. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  1080. )
  1081. elif input_ids is not None:
  1082. batch_size, seq_length = input_ids.shape
  1083. elif inputs_embeds is not None:
  1084. batch_size, seq_length, _ = inputs_embeds.shape
  1085. else:
  1086. raise ValueError(
  1087. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1088. )
  1089. if past_key_values is None:
  1090. past_key_values = tuple([None] * len(self.layers))
  1091. # NOTE: convert to a list so cache entries can be cleared in time
  1092. past_key_values = list(past_key_values)
  1093. seq_length_with_past = seq_length
  1094. cache_length = 0
  1095. if past_key_values[0] is not None:
  1096. cache_length = past_key_values[0][0].shape[1]
  1097. seq_length_with_past += cache_length
  1098. if inputs_embeds is None:
  1099. # [bs, seq_len, dim]
  1100. inputs_embeds = self.embed_tokens(input_ids)
  1101. if self.sequence_parallel:
  1102. # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
  1103. bs, seq_len, hidden_size = inputs_embeds.shape
  1104. inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size])
  1105. # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
  1106. inputs_embeds = ScatterOp.apply(inputs_embeds)
  1107. # [bs, seq_len]
  1108. attention_mask = (
  1109. paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
  1110. if attention_mask is None
  1111. else attention_mask
  1112. )
  1113. attention_mask = self._prepare_decoder_attention_mask(
  1114. attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
  1115. ) # [bs, 1, seq_len, seq_len]
  1116. if self.config.use_flash_attention:
  1117. attention_mask = None if is_casual_mask(attention_mask) else attention_mask
  1118. if position_ids is None:
  1119. position_ids = paddle.arange(seq_length, dtype="int64").expand(
  1120. (batch_size, seq_length)
  1121. )
  1122. hidden_states = inputs_embeds
  1123. # decoder layers
  1124. all_hidden_states = () if output_hidden_states else None
  1125. all_self_attns = () if output_attentions else None
  1126. next_decoder_cache = () if use_cache else None
  1127. for idx, (decoder_layer) in enumerate(self.layers):
  1128. if output_hidden_states:
  1129. all_hidden_states += (hidden_states,)
  1130. past_key_value = (
  1131. past_key_values[idx] if past_key_values is not None else None
  1132. )
  1133. has_gradient = not hidden_states.stop_gradient
  1134. if (
  1135. self.enable_recompute
  1136. and idx not in self.no_recompute_layers
  1137. and has_gradient
  1138. and self.recompute_granularity == "full"
  1139. ):
  1140. layer_outputs = self.recompute_training_full(
  1141. decoder_layer,
  1142. hidden_states,
  1143. position_ids,
  1144. attention_mask,
  1145. output_attentions,
  1146. past_key_value,
  1147. use_cache,
  1148. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  1149. )
  1150. else:
  1151. layer_outputs = decoder_layer(
  1152. hidden_states,
  1153. position_ids,
  1154. attention_mask,
  1155. output_attentions,
  1156. past_key_value,
  1157. use_cache,
  1158. attn_mask_startend_row_indices=attn_mask_startend_row_indices,
  1159. )
  1160. # NOTE: clear the outdated cache entry after it has been used, to save memory
  1161. past_key_value = past_key_values[idx] = None
  1162. if type(layer_outputs) is tuple:
  1163. hidden_states = layer_outputs[0]
  1164. else:
  1165. hidden_states = layer_outputs
  1166. if output_attentions:
  1167. all_self_attns += (layer_outputs[1],)
  1168. if use_cache:
  1169. next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
  1170. hidden_states = self.norm(hidden_states)
  1171. # add hidden states from the last decoder layer
  1172. if output_hidden_states:
  1173. all_hidden_states += (hidden_states,)
  1174. next_cache = next_decoder_cache if use_cache else None
  1175. if not return_dict:
  1176. return tuple(
  1177. v
  1178. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
  1179. if v is not None
  1180. )
  1181. return BaseModelOutputWithPast(
  1182. last_hidden_state=hidden_states,
  1183. past_key_values=next_cache,
  1184. hidden_states=all_hidden_states,
  1185. attentions=all_self_attns,
  1186. )
  1187. class Qwen2PretrainingCriterion(nn.Layer):
  1188. """
  1189. Criterion for Qwen2.
  1190. It calculates the final loss.
  1191. """
  1192. def __init__(self, config: Qwen2Config):
  1193. super(Qwen2PretrainingCriterion, self).__init__()
  1194. self.ignore_index = getattr(config, "ignore_index", -100)
  1195. self.config = config
  1196. self.enable_parallel_cross_entropy = (
  1197. config.tensor_parallel_degree > 1 and config.tensor_parallel_output
  1198. )
  1199. if (
  1200. self.enable_parallel_cross_entropy
  1201. ):  # lm_head is distributed
  1202. self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
  1203. else:
  1204. self.loss_func = paddle.nn.CrossEntropyLoss(
  1205. reduction="none", ignore_index=self.ignore_index
  1206. )
  1207. def forward(self, prediction_scores, masked_lm_labels):
  1208. if self.enable_parallel_cross_entropy:
  1209. if prediction_scores.shape[-1] == self.config.vocab_size:
  1210. logging.warning(
  1211. f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}"
  1212. )
  1213. self.loss_func = paddle.nn.CrossEntropyLoss(
  1214. reduction="none", ignore_index=self.ignore_index
  1215. )
  1216. with paddle.amp.auto_cast(False):
  1217. masked_lm_loss = self.loss_func(
  1218. prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
  1219. )
  1220. # skip ignore_index positions, for which loss == 0
  1221. # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
  1222. # loss = paddle.mean(masked_lm_loss)
  1223. binary_sequence = paddle.where(
  1224. masked_lm_loss > 0,
  1225. paddle.ones_like(masked_lm_loss),
  1226. paddle.zeros_like(masked_lm_loss),
  1227. )
  1228. count = paddle.sum(binary_sequence)
  1229. if count == 0:
  1230. loss = paddle.sum(masked_lm_loss * binary_sequence)
  1231. else:
  1232. loss = paddle.sum(masked_lm_loss * binary_sequence) / count
  1233. return loss
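# For reference (not part of the original module): the reduction above is a
# masked mean over non-ignored tokens,
#
#     loss = sum(per_token_loss * (per_token_loss > 0)) / max(count, 1)
#
# where positions labeled with ignore_index contribute zero loss and are
# excluded from the denominator (the count == 0 branch avoids division by zero).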
class Qwen2LMHead(nn.Layer):
    def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
        super(Qwen2LMHead, self).__init__()
        self.config = config
        if (
            config.tensor_parallel_degree > 1
            and config.vocab_size % config.tensor_parallel_degree == 0
        ):
            # Each tensor-parallel rank holds an equal shard of the vocabulary.
            vocab_size = config.vocab_size // config.tensor_parallel_degree
        else:
            vocab_size = config.vocab_size

        self.transpose_y = transpose_y
        if transpose_y:
            # Reuse (tied) embedding weights of shape [vocab_size, hidden_size] if provided.
            if embedding_weights is not None:
                self.weight = embedding_weights
            else:
                self.weight = self.create_parameter(
                    shape=[vocab_size, config.hidden_size],
                    dtype=paddle.get_default_dtype(),
                )
        else:
            # The projection weight is [hidden_size, vocab_size]; vocab_size is already
            # the per-rank shard when tensor parallelism is enabled.
            self.weight = self.create_parameter(
                shape=[config.hidden_size, vocab_size],
                dtype=paddle.get_default_dtype(),
            )

        # Must set distributed attr for Tensor Parallel!
        self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
        if self.weight.is_distributed:
            # for tie_word_embeddings
            self.weight.split_axis = 0 if self.transpose_y else 1
    def forward(self, hidden_states, tensor_parallel_output=None):
        if self.config.sequence_parallel:
            hidden_states = GatherOp.apply(hidden_states)
            seq_length = self.config.seq_length
            hidden_states = paddle.reshape_(
                hidden_states, [-1, seq_length, self.config.hidden_size]
            )

        if tensor_parallel_output is None:
            tensor_parallel_output = self.config.tensor_parallel_output

        logits = parallel_matmul(
            hidden_states,
            self.weight,
            transpose_y=self.transpose_y,
            tensor_parallel_output=tensor_parallel_output,
        )
        return logits
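
# Illustrative shard arithmetic (not part of the original code); the sizes below are
# example values only. With tensor parallelism the head weight is a per-rank vocab shard:
#
#   vocab_size, tp_degree, hidden_size = 152064, 4, 3584
#   shard = vocab_size // tp_degree          # 38016 vocab entries per rank
#   # non-transposed head weight per rank: [hidden_size, shard] = [3584, 38016]
#   # tied/transposed head weight per rank: [shard, hidden_size] = [38016, 3584]
#
# parallel_matmul then either keeps the sharded logits (tensor_parallel_output=True,
# paired with ParallelCrossEntropy) or gathers them back to the full vocabulary.
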
class Qwen2ForCausalLM(Qwen2PretrainedModel):
    enable_to_static_method = True
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.qwen2 = Qwen2Model(config)
        if config.tie_word_embeddings:
            self.lm_head = Qwen2LMHead(
                config,
                embedding_weights=self.qwen2.embed_tokens.weight,
                transpose_y=True,
            )
            self.tie_weights()
        else:
            self.lm_head = Qwen2LMHead(config)
        self.criterion = Qwen2PretrainingCriterion(config)
        self.vocab_size = config.vocab_size

    def get_input_embeddings(self):
        return self.qwen2.embed_tokens

    def set_input_embeddings(self, value):
        self.qwen2.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.qwen2 = decoder

    def get_decoder(self):
        return self.qwen2
    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache=False,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        batch_size, seq_length = input_ids.shape
        position_ids = kwargs.get(
            "position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))
        )
        if past_key_values:
            # With a cache, only the newest token (and its position) needs to be fed.
            input_ids = input_ids[:, -1].unsqueeze(axis=-1)
            position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
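
    # Illustrative sketch (not part of the original code): the first generation call sees
    # the full prompt, while cached steps only feed the newest token. "model" and "cache"
    # below are hypothetical placeholders.
    #
    #   ids = paddle.to_tensor([[1, 2, 3, 4]])
    #   first = model.prepare_inputs_for_generation(ids, use_cache=True)
    #   # first["input_ids"].shape  -> [1, 4]
    #   # first["position_ids"]     -> [[0, 1, 2, 3]]
    #   later = model.prepare_inputs_for_generation(ids, use_cache=True, past_key_values=cache)
    #   # later["input_ids"].shape  -> [1, 1]   (only the last token)
    #   # later["position_ids"]     -> [[3]]
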
    def _get_model_inputs_spec(self, dtype: str):
        return {
            "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
            "attention_mask": paddle.static.InputSpec(
                shape=[None, None], dtype="int64"
            ),
            "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
        }
    @staticmethod
    def update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        # update cache
        if (
            isinstance(outputs, tuple)
            and len(outputs) > 1
            and not isinstance(outputs[1], paddle.Tensor)
        ):
            model_kwargs["past_key_values"] = outputs[1]

        if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update position_ids
        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
            position_ids = model_kwargs["position_ids"]
            model_kwargs["position_ids"] = paddle.concat(
                [position_ids, position_ids[..., -1:] + 1], axis=-1
            )

        if not is_encoder_decoder and "attention_mask" in model_kwargs:
            # TODO: support attention mask for other models
            attention_mask = model_kwargs["attention_mask"]
            if len(attention_mask.shape) == 2:
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [attention_mask.shape[0], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )
            elif len(attention_mask.shape) == 4:
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [*attention_mask.shape[:3], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )[:, :, -1:, :]

        return model_kwargs
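
    # Illustrative sketch (not part of the original code): each decoding step appends one
    # column to a 2D attention mask and one new position id. "logits" and "cache" below
    # are hypothetical placeholders; a 4D mask is additionally sliced to its last row.
    #
    #   kwargs = {
    #       "attention_mask": paddle.ones([1, 4], dtype="int64"),
    #       "position_ids": paddle.to_tensor([[0, 1, 2, 3]]),
    #   }
    #   kwargs = Qwen2ForCausalLM.update_model_kwargs_for_generation((logits, cache), kwargs)
    #   # kwargs["attention_mask"].shape -> [1, 5]
    #   # kwargs["position_ids"]         -> [[0, 1, 2, 3, 4]]
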
    def forward(
        self,
        input_ids: paddle.Tensor = None,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        labels: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        attn_mask_startend_row_indices=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are
                ignored (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from paddlenlp.transformers import AutoTokenizer, Qwen2ForCausalLM

        >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pd")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if attn_mask_startend_row_indices is not None and attention_mask is not None:
            logging.warning(
                "You have provided both attn_mask_startend_row_indices and attention_mask. "
                "The attn_mask_startend_row_indices will be used."
            )
            attention_mask = None

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.qwen2(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
        )

        hidden_states = outputs[0]

        # if labels is None, we need the full output instead of tensor_parallel_output;
        # tensor_parallel_output goes together with ParallelCrossEntropy
        tensor_parallel_output = (
            self.config.tensor_parallel_output
            and self.config.tensor_parallel_degree > 1
        )

        logits = self.lm_head(
            hidden_states, tensor_parallel_output=tensor_parallel_output
        )

        loss = None
        if labels is not None:
            # Compute the masked LM loss with the pretraining criterion when labels are given;
            # otherwise labels would be an unused argument.
            loss = self.criterion(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
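
# Illustrative end-to-end sketch (not part of the original code). The checkpoint name,
# the tokenizer import, and the generate() return convention below are assumptions,
# shown the way PaddleNLP-style APIs are typically used:
#
#   from paddlenlp.transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")   # placeholder checkpoint
#   model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", dtype="bfloat16")
#   inputs = tokenizer("Hello, Qwen2!", return_tensors="pd")
#   output = model.generate(**inputs, max_new_tokens=32)
#   print(tokenizer.batch_decode(output[0], skip_special_tokens=True))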