| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from functools import partial
- from typing import List, Optional, Tuple, Union
- import paddle
- import paddle.distributed.fleet.meta_parallel as mpu
- import paddle.nn as nn
- import paddle.nn.functional as F
- from paddle import Tensor
- from paddle.distributed import fleet
- from paddle.distributed.fleet.utils import sequence_parallel_utils
- from .....utils import logging
- from .....utils.env import get_device_type
- from ...common.vlm import fusion_ops
- from ...common.vlm.activations import ACT2FN
- from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
- from ...common.vlm.transformers.model_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- )
# Optional fused rotary-embedding kernel. Left as ``None`` when this Paddle
# build does not ship it; Qwen2Attention checks for ``None`` before use.
try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

# Sequence-parallel helpers. The import stays best-effort, but only ordinary
# exceptions are swallowed (a bare ``except:`` would also hide
# KeyboardInterrupt / SystemExit) and the names are always bound so a missing
# helper can be detected with an ``is None`` check.
try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        GatherOp,
        ScatterOp,
        mark_as_sequence_parallel_parameter,
    )
except Exception:
    GatherOp = None
    ScatterOp = None
    mark_as_sequence_parallel_parameter = None

# Optional flash-attention kernel; ``None`` when unavailable.
try:
    from paddle.nn.functional.flash_attention import flash_attention
except Exception:
    flash_attention = None

# Module-level aliases for the linear layer flavours used by the model blocks.
Linear = nn.Linear
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
class Qwen2Config(PretrainedConfig):
    r"""
    Configuration container for a [`Qwen2Model`].

    Every constructor argument is stored on the instance under the same name. The token ids and
    `tie_word_embeddings` are additionally forwarded to [`PretrainedConfig`] together with any extra
    keyword arguments.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model, i.e. the number of distinct token ids accepted as `input_ids`.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads in each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads used for Grouped Query Attention. With
            `num_key_value_heads=num_attention_heads` the model uses Multi Head Attention (MHA), with
            `num_key_value_heads=1` it uses Multi Query Attention (MQA), otherwise GQA. When converting a
            multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by
            mean-pooling the original heads of that group, see [this paper](https://arxiv.org/pdf/2305.13245.pdf).
            Passing `None` falls back to `num_attention_heads` (backward compatibility).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        seq_length (`int`, *optional*, defaults to 32768):
            Sequence length stored on the config; `Qwen2Attention` reads it when reshaping projections under
            sequence parallelism.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 151643):
            Beginning-of-sequence token id.
        eos_token_id (`int`, *optional*, defaults to 151643):
            End-of-sequence token id.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA. The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope_scaling_factor (`float`, *optional*, defaults to 1.0):
            Stored as-is on the config.
        rope_scaling_type (`str`, *optional*, defaults to `None`):
            Stored as-is on the config.
        dpo_config (*optional*, defaults to `None`):
            Stored as-is on the config.
    """

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        seq_length=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        pad_token_id=0,
        bos_token_id=151643,
        eos_token_id=151643,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        rope_scaling_factor=1.0,
        rope_scaling_type=None,
        dpo_config=None,
        **kwargs,
    ):
        # Older configs may omit the KV head count; treat that as plain MHA.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act

        # Sequence lengths and positional encoding.
        self.max_position_embeddings = max_position_embeddings
        self.seq_length = seq_length
        self.rope_theta = rope_theta
        self.rope_scaling_factor = rope_scaling_factor
        self.rope_scaling_type = rope_scaling_type

        # Attention behaviour.
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers
        self.attention_dropout = attention_dropout

        # Initialisation, normalisation and caching.
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Special tokens and auxiliary training configuration.
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.dpo_config = dpo_config

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
def get_triangle_upper_mask(x, mask=None):
    """Return an additive causal mask matching the attention scores ``x``.

    Args:
        x: attention scores; the code treats axis 1 as the head axis and the
            last two axes as ``(q_len, kv_seq_len)``.
        mask: an already-built mask. When given it is returned unchanged.

    Returns:
        ``mask`` if provided, otherwise a tensor shaped like ``x`` with axis 1
        collapsed to size 1. Allowed positions hold 0 and disallowed (future)
        positions hold the most negative finite value of ``x.dtype``. Gradients
        are disabled on the generated mask.
    """
    if mask is not None:
        return mask

    # [bsz, n_head, q_len, kv_seq_len] -> [bsz, 1, q_len, kv_seq_len]
    shape = list(x.shape)
    shape[1] = 1

    # When cached keys are present the queries are the *last* q_len positions of
    # the kv_seq_len keys, so the causal diagonal has to move right by the cache
    # length. With diagonal=1 unconditionally, every cached key except the first
    # would be masked out. For q_len == kv_seq_len this reduces to diagonal=1.
    diagonal = 1
    q_len, kv_seq_len = shape[-2], shape[-1]
    if isinstance(q_len, int) and isinstance(kv_seq_len, int) and kv_seq_len > q_len > 0:
        diagonal += kv_seq_len - q_len

    mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
    mask = paddle.triu(mask, diagonal=diagonal)
    mask.stop_gradient = True
    return mask
def parallel_matmul(
    x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
):
    """Multiply ``x`` by ``y``, using the model-parallel group when ``y`` is sharded.

    Args:
        x: left operand.
        y: right operand (typically a weight); in dynamic mode its
            ``is_distributed`` flag decides whether the parallel path is taken.
        transpose_y: forwarded to ``paddle.matmul``.
        tensor_parallel_output: on the parallel path, return the local result
            as-is when True; otherwise pass it through ``_c_concat`` over the
            model-parallel group first.

    Returns:
        The matmul result (local or concatenated, see above).
    """
    model_parallel_group = None
    tensor_parallel_degree = 1
    try:
        # Without distributed.launch this raises
        # AttributeError: 'Fleet' object has no attribute '_hcg'.
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        tensor_parallel_degree = hcg.get_model_parallel_world_size()
        is_fleet_init = True
    except Exception:
        # Best-effort probe: any failure means "fleet is not initialised".
        # Unlike a bare ``except:`` this lets KeyboardInterrupt/SystemExit through.
        is_fleet_init = False

    if paddle.in_dynamic_mode():
        y_is_distributed = y.is_distributed
    else:
        # Static graph: fall back to the parallel degree as the indicator.
        y_is_distributed = tensor_parallel_degree > 1

    if not (is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed):
        return paddle.matmul(x, y, transpose_y=transpose_y)

    input_parallel = paddle.distributed.collective._c_identity(
        x, group=model_parallel_group
    )
    logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
    if tensor_parallel_output:
        return logits
    return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
def scaled_dot_product_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    attn_mask_startend_row_indices=None,
    training=True,
    sequence_parallel=False,
    skip_recompute=False,
):
    """Plain (non-fused) softmax attention.

    Args:
        query_states: ``[bsz, q_len, num_heads, head_dim]``.
        config: object providing ``attention_dropout``.
        key_states, value_states: ``[bsz, kv_seq_len, num_heads, head_dim]``.
        attention_mask: additive mask reshaped to ``[bsz, 1, q_len, kv_seq_len]``;
            when ``None`` a causal mask is generated.
        output_attentions: also return the attention probabilities when truthy.
        attn_mask_startend_row_indices: accepted but not used here.
        training: forwarded to dropout.
        sequence_parallel: when True the output is flattened to
            ``[bsz * q_len, num_heads * head_dim]`` instead of
            ``[bsz, q_len, num_heads * head_dim]``.
        skip_recompute: accepted but not used here.

    Returns:
        ``attn_output`` or ``(attn_output, attn_weights)`` when
        ``output_attentions`` is truthy.

    Raises:
        ValueError: if the score or mask tensors do not have the expected shape.
    """
    bsz, q_len, num_heads, head_dim = query_states.shape
    _, kv_seq_len, _, _ = value_states.shape

    # [bsz, seq_len, n_head, head_dim] -> [bsz, n_head, seq_len, head_dim]
    heads_first = [0, 2, 1, 3]
    query_states = paddle.transpose(query_states, heads_first)
    key_states = paddle.transpose(key_states, heads_first)
    value_states = paddle.transpose(value_states, heads_first)

    # Under float16 the scores are additionally divided by 32 before the matmul
    # and multiplied back inside the softmax to avoid NaNs.
    needs_pre_divide = paddle.in_dynamic_mode() and query_states.dtype == paddle.float16
    pre_divided_factor = 32 if needs_pre_divide else 1

    scaled_query = query_states / (math.sqrt(head_dim) * pre_divided_factor)
    attn_weights = paddle.matmul(scaled_query, key_states.transpose([0, 1, 3, 2]))

    if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
        raise ValueError(
            f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.shape}"
        )

    if attention_mask is None:
        attention_mask = get_triangle_upper_mask(attn_weights)
    attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
    if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
        raise ValueError(
            f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
        )
    attn_weights = attn_weights + attention_mask

    # Softmax is always evaluated in float32 and cast back afterwards.
    if paddle.in_dynamic_mode():
        with paddle.amp.auto_cast(False):
            attn_weights = F.softmax(
                attn_weights.astype("float32") * pre_divided_factor,
                axis=-1,
                dtype="float32",
            ).astype(query_states.dtype)
    else:
        attn_weights = F.softmax(
            attn_weights * pre_divided_factor, axis=-1, dtype="float32"
        ).astype(query_states.dtype)

    attn_weights = F.dropout(
        attn_weights, p=config.attention_dropout, training=training
    )

    # [bsz, n_head, q_len, head_dim] -> [bsz, q_len, n_head, head_dim]
    attn_output = paddle.matmul(attn_weights, value_states).transpose([0, 2, 1, 3])

    merged_dim = head_dim * num_heads
    if sequence_parallel:
        out_shape = [bsz * q_len, merged_dim]
    else:
        out_shape = [bsz, q_len, merged_dim]
    attn_output = attn_output.reshape(out_shape)

    if output_attentions:
        return (attn_output, attn_weights)
    return attn_output
def is_casual_mask(attention_mask):
    """
    Return True when every element of `attention_mask` equals the corresponding element of its
    upper-triangular part, i.e. `paddle.triu(attention_mask) == attention_mask` holds everywhere.
    """
    upper_part = paddle.triu(attention_mask)
    elementwise_equal = upper_part == attention_mask
    return elementwise_equal.all().item()
def _make_causal_mask(input_ids_shape, past_key_values_length):
    """
    Build a boolean causal mask of shape `[batch_size, 1, tgt_len, tgt_len + past_len]`.

    `True` marks positions that may be attended to: the lower triangle over the current tokens,
    plus every cached (past) position.
    """
    batch_size, target_length = input_ids_shape  # target_length: seq_len

    # [tgt_len, tgt_len] lower-triangular block for the current tokens
    causal = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))

    if past_key_values_length > 0:
        # Cached positions are fully visible: prepend a block of ones.
        # -> [tgt_len, past_len + tgt_len]
        history = paddle.ones([target_length, past_key_values_length], dtype="bool")
        causal = paddle.concat([history, causal], axis=-1)

    full_shape = [
        batch_size,
        1,
        target_length,
        target_length + past_key_values_length,
    ]
    return causal[None, None, :, :].expand(full_shape)
def _expand_2d_mask(mask, dtype, tgt_length):
    """
    Broadcast a `[batch_size, src_length]` mask to a boolean `[batch_size, 1, tgt_length, src_length]` mask.

    `tgt_length` defaults to `src_length` when `None`. The `dtype` argument is accepted but not used;
    the result is always boolean and has gradients disabled.
    """
    batch_size = mask.shape[0]
    src_length = mask.shape[-1]
    if tgt_length is None:
        tgt_length = src_length

    bool_mask = mask[:, None, None, :].astype("bool")
    bool_mask.stop_gradient = True
    return bool_mask.expand([batch_size, 1, tgt_length, src_length])
class Qwen2RMSNorm(nn.Layer):
    """Root-mean-square layer normalisation with a learnable per-channel scale."""

    def __init__(self, config: Qwen2Config):
        """
        Create the scale parameter (initialised to 1.0) of size `config.hidden_size`.

        When `config.sequence_parallel` is set, the weight is registered via
        `mark_as_sequence_parallel_parameter`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.variance_epsilon = config.rms_norm_eps
        self.weight = paddle.create_parameter(
            shape=[self.hidden_size],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(1.0),
        )

        if config.sequence_parallel:
            mark_as_sequence_parallel_parameter(self.weight)

    def forward(self, hidden_states):
        """Normalise `hidden_states` over the last axis and scale by `self.weight`."""
        if self.config.use_fused_rms_norm:
            return fusion_ops.fusion_rms_norm(
                hidden_states, self.weight, self.variance_epsilon, False
            )

        if not paddle.in_dynamic_mode():
            mean_square = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
            hidden_states = (
                paddle.rsqrt(mean_square + self.variance_epsilon) * hidden_states
            )
        else:
            # Keep AMP from down-casting the statistics, which are computed in float32.
            with paddle.amp.auto_cast(False):
                mean_square = (
                    hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
                )
                hidden_states = (
                    paddle.rsqrt(mean_square + self.variance_epsilon) * hidden_states
                )

        # Bring the activations back to the (half-precision) weight dtype before scaling.
        if self.weight.dtype in (paddle.float16, paddle.bfloat16):
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
        return hidden_states * self.weight
class Qwen2RotaryEmbedding(nn.Layer):
    """Precomputes and serves the cos/sin tables used for rotary position embedding."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        """
        Args:
            dim: rotary dimension (the per-head size at the call site in this file).
            max_position_embeddings: number of positions to precompute up front.
            base: base of the geometric frequency progression.
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # [dim / 2]
        self.inv_freq = 1.0 / (
            self.base
            ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
        )
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build the cached cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seq_len, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, x, seq_len=None):
        """
        Return ``(cos, sin)`` tables of shape ``[1, seq_len, 1, dim]`` cast to ``x.dtype``.

        Args:
            x: tensor whose dtype selects the output dtype (only ``x.dtype`` is read).
            seq_len: number of positions required. ``None`` returns the whole
                cached table; a value larger than the cache grows the cache first.
        """
        if seq_len is None:
            # The default used to hit ``None > int`` and raise TypeError.
            seq_len = self.max_seq_len_cached
        elif seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:, :seq_len, :, :]
        sin = self.sin_cached[:, :seq_len, :, :]
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )
def rotate_half(x):
    """Split the last axis in two halves `(a, b)` and return `concat(-b, a)`; the shape is unchanged."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return paddle.concat([-back, front], axis=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embedding to `q` and `k`.

    `cos` / `sin` are the `[1, seq_len, 1, dim]` tables. With `position_ids` the tables are gathered
    per token; without them the first `q.shape[1]` positions are used.

    Returns the rotated `(q, k)` pair.
    """
    if position_ids is not None:
        # [1, seq_len, 1, dim] -> [seq_len, dim] -> gather -> [bs, seq_len, 1, dim]
        cos = cos.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)
        sin = sin.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)
    else:
        # Note: Only for Qwen2MoEForCausalLMPipe model pretraining
        num_positions = q.shape[1]
        cos = cos[:, :num_positions, :, :]  # [bs, seq_len, 1, dim]
        sin = sin[:, :num_positions, :, :]  # [bs, seq_len, 1, dim]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class Qwen2MLP(nn.Layer):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None):
        """
        Args:
            config: model configuration; reads ``hidden_size``, ``intermediate_size``,
                ``fuse_attention_ffn``, ``tensor_parallel_degree``,
                ``sequence_parallel`` and ``hidden_act``.
            is_shared: accepted for interface compatibility; not used here.
            skip_recompute_ops: stored on the instance (defaults to ``{}``).
        """
        super().__init__()
        if skip_recompute_ops is None:
            skip_recompute_ops = {}
        self.skip_recompute_ops = skip_recompute_ops
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.fuse_attention_ffn = config.fuse_attention_ffn
        self.tensor_parallel_degree = config.tensor_parallel_degree

        # Pick the parallel linear flavour under *distinct local names*.
        # Re-binding ``ColumnParallelLinear`` / ``RowParallelLinear`` inside this
        # function made them function-local, so the non-sequence-parallel
        # tensor-parallel path raised UnboundLocalError.
        if config.sequence_parallel:
            column_linear_cls = ColumnSequenceParallelLinear
            row_linear_cls = RowSequenceParallelLinear
        else:
            column_linear_cls = ColumnParallelLinear
            row_linear_cls = RowParallelLinear

        if config.tensor_parallel_degree > 1:
            if self.fuse_attention_ffn:
                # gate and up projections fused into one column-parallel matmul
                self.gate_up_fused_proj = column_linear_cls(
                    self.hidden_size,
                    self.intermediate_size * 2,
                    gather_output=False,
                    has_bias=False,
                )
            else:
                self.gate_proj = column_linear_cls(
                    self.hidden_size,
                    self.intermediate_size,
                    gather_output=False,
                    has_bias=False,
                )
                self.up_proj = column_linear_cls(
                    self.hidden_size,
                    self.intermediate_size,
                    gather_output=False,
                    has_bias=False,
                )
            self.down_proj = row_linear_cls(
                self.intermediate_size,
                self.hidden_size,
                input_is_parallel=True,
                has_bias=False,
            )
        else:
            if self.fuse_attention_ffn:
                self.gate_up_fused_proj = Linear(
                    self.hidden_size, self.intermediate_size * 2, bias_attr=False
                )
            else:
                self.gate_proj = Linear(
                    self.hidden_size, self.intermediate_size, bias_attr=False
                )  # w1
                self.up_proj = Linear(
                    self.hidden_size, self.intermediate_size, bias_attr=False
                )  # w3
            self.down_proj = Linear(
                self.intermediate_size, self.hidden_size, bias_attr=False
            )  # w2

        if config.hidden_act == "silu":
            # silu is served by the swiglu helper, which takes (gate, up) together.
            self.act_fn = fusion_ops.swiglu
            self.fuse_swiglu = True
        else:
            self.act_fn = ACT2FN[config.hidden_act]
            self.fuse_swiglu = False

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the down-projected result."""
        if self.fuse_attention_ffn:
            x = self.gate_up_fused_proj(x)
            if self.fuse_swiglu:
                # The fused tensor is handed to swiglu whole, with no second operand.
                y = None
            else:
                x, y = x.chunk(2, axis=-1)
        else:
            x, y = self.gate_proj(x), self.up_proj(x)

        if self.fuse_swiglu:
            x = self.act_fn(x, y)
        else:
            x = self.act_fn(x) * y
        return self.down_proj(x)
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    The input is laid out as (batch, seqlen, num_key_value_heads, head_dim) and the result as
    (batch, seqlen, num_key_value_heads * n_rep, head_dim), with the copies of each head adjacent
    (equivalent to a repeat-interleave along axis 2). With `n_rep == 1` the input is returned as-is.
    """
    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # [b, s, h, d] -> [b, s, h, 1, d] -> [b, s, h, n_rep, d]
    expanded = hidden_states.unsqueeze(-2)
    expanded = expanded.tile([1, 1, 1, n_rep, 1])
    # Merge the (h, n_rep) axes so the copies of one head sit next to each other.
    merged_shape = [batch, slen, num_key_value_heads * n_rep, head_dim]
    return expanded.reshape(merged_shape)
class Qwen2Attention(nn.Layer):
    """
    Multi-head / grouped-query self-attention with rotary position embedding.

    Projects the hidden states to query/key/value (optionally through one fused
    projection), applies RoPE, concatenates any cached key/values, runs
    ``scaled_dot_product_attention`` and applies the output projection.
    """

    def __init__(
        self,
        config: Qwen2Config,
        layerwise_recompute: bool = True,
        skip_recompute_ops=None,
    ):
        """
        Args:
            config: model configuration.
            layerwise_recompute: stored on the instance; recompute only happens
                when both this and ``enable_recompute`` are True.
            skip_recompute_ops: stored on the instance (defaults to ``{}``).
        """
        super().__init__()
        if skip_recompute_ops is None:
            skip_recompute_ops = {}
        self.config = config
        self.skip_recompute_ops = skip_recompute_ops

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Query heads are split evenly over the KV heads. The previous check
        # (``heads // kv_heads``) only verified the quotient was non-zero.
        assert (
            config.num_attention_heads % config.num_key_value_heads == 0
        ), f"num_attention_heads: {config.num_attention_heads}, num_key_value_heads: {config.num_key_value_heads}"
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        self.seq_length = config.seq_length
        self.sequence_parallel = config.sequence_parallel

        self.fuse_attention_qkv = config.fuse_attention_qkv

        # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
        # Enable_recompute defaults to False and is controlled by Trainer
        self.enable_recompute = False
        self.layerwise_recompute = layerwise_recompute
        self.recompute_granularity = config.recompute_granularity

        if config.tensor_parallel_degree > 1:
            # From here on the head counts are per tensor-parallel rank.
            assert (
                self.num_heads % config.tensor_parallel_degree == 0
            ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
            self.num_heads = self.num_heads // config.tensor_parallel_degree

            assert (
                self.num_key_value_heads % config.tensor_parallel_degree == 0
            ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
            self.num_key_value_heads = (
                self.num_key_value_heads // config.tensor_parallel_degree
            )

        self.use_fused_rope = config.use_fused_rope
        if self.use_fused_rope:
            if (
                get_device_type() not in ["gpu", "xpu"]
                or fused_rotary_position_embedding is None
            ):
                logging.warning(
                    "Enable fuse rope in the config, but fuse rope is not available. "
                    "Will disable fuse rope. Try using latest gpu version of Paddle."
                )
                self.use_fused_rope = False

        # Pick the parallel linear flavour under *distinct local names*.
        # Re-binding ``ColumnParallelLinear`` / ``RowParallelLinear`` inside this
        # function made them function-local, so the non-sequence-parallel
        # tensor-parallel path raised UnboundLocalError.
        if config.sequence_parallel:
            column_linear_cls = ColumnSequenceParallelLinear
            row_linear_cls = RowSequenceParallelLinear
        else:
            column_linear_cls = ColumnParallelLinear
            row_linear_cls = RowParallelLinear

        # Output widths use the *global* KV head count from the config.
        kv_out_features = self.config.num_key_value_heads * self.head_dim

        if config.tensor_parallel_degree > 1:
            if self.fuse_attention_qkv:
                self.qkv_proj = column_linear_cls(
                    self.hidden_size,
                    self.hidden_size + 2 * kv_out_features,
                    has_bias=True,
                    gather_output=False,
                )
            else:
                self.q_proj = column_linear_cls(
                    self.hidden_size,
                    self.hidden_size,
                    has_bias=True,
                    gather_output=False,
                )
                self.k_proj = column_linear_cls(
                    self.hidden_size,
                    kv_out_features,
                    has_bias=True,
                    gather_output=False,
                )
                self.v_proj = column_linear_cls(
                    self.hidden_size,
                    kv_out_features,
                    has_bias=True,
                    gather_output=False,
                )
            self.o_proj = row_linear_cls(
                self.hidden_size,
                self.hidden_size,
                has_bias=False,
                input_is_parallel=True,
            )
        else:
            if self.fuse_attention_qkv:
                self.qkv_proj = Linear(
                    self.hidden_size,
                    self.hidden_size + 2 * kv_out_features,
                )
            else:
                self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
                self.k_proj = Linear(
                    self.hidden_size,
                    kv_out_features,
                    bias_attr=True,
                )
                self.v_proj = Linear(
                    self.hidden_size,
                    kv_out_features,
                    bias_attr=True,
                )
            self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)

        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

        self.attn_func = scaled_dot_product_attention

    def forward(
        self,
        hidden_states,
        position_ids: Optional[Tuple[paddle.Tensor]] = None,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
        **kwargs,
    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: layer input.
            position_ids: positions used to gather the rotary tables (optional).
            past_key_value: cached ``(key, value)`` pair, concatenated along the
                sequence axis (axis 1) in front of the new key/values.
            attention_mask: forwarded to the attention function.
            output_attentions: also return the attention probabilities.
            use_cache: also return the updated ``(key, value)`` pair.
            attn_mask_startend_row_indices: forwarded to the attention function.

        Returns:
            ``attn_output`` alone when neither ``output_attentions`` nor
            ``use_cache`` is set; otherwise a tuple starting with ``attn_output``,
            followed by the attention weights (if requested) and the cache (if
            requested).
        """
        # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)

        if self.fuse_attention_qkv:
            mix_layer = self.qkv_proj(hidden_states)
            # Group the fused projection per KV head: each group holds its
            # query heads followed by one key and one value head.
            if self.sequence_parallel:
                target_shape = [
                    -1,
                    self.seq_length,
                    self.num_key_value_heads,
                    (self.num_key_value_groups + 2) * self.head_dim,
                ]
            else:
                target_shape = [
                    0,
                    0,
                    self.num_key_value_heads,
                    (self.num_key_value_groups + 2) * self.head_dim,
                ]
            mix_layer = paddle.reshape_(mix_layer, target_shape)
            query_states, key_states, value_states = paddle.split(
                mix_layer,
                num_or_sections=[
                    self.num_key_value_groups * self.head_dim,
                    self.head_dim,
                    self.head_dim,
                ],
                axis=-1,
            )
            if self.gqa_or_mqa:
                # Unfold the grouped query heads back to [.., num_heads, head_dim].
                query_states = paddle.reshape_(
                    query_states, [0, 0, self.num_heads, self.head_dim]
                )
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

            if self.sequence_parallel:
                target_query_shape = [
                    -1,
                    self.seq_length,
                    self.num_heads,
                    self.head_dim,
                ]
                target_key_value_shape = [
                    -1,
                    self.seq_length,
                    self.num_key_value_heads,
                    self.head_dim,
                ]
            else:
                target_query_shape = [0, 0, self.num_heads, self.head_dim]
                target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
            query_states = query_states.reshape(shape=target_query_shape)
            key_states = key_states.reshape(shape=target_key_value_shape)
            value_states = value_states.reshape(shape=target_key_value_shape)

        # Total key length = new tokens + cached tokens.
        kv_seq_len = key_states.shape[-3]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]

        if self.use_fused_rope:
            assert past_key_value is None, "fuse rotary not support cache kv for now"
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )

        # [bs, seq_len, num_head, head_dim]
        if past_key_value is not None:
            key_states = paddle.concat([past_key_value[0], key_states], axis=1)
            value_states = paddle.concat([past_key_value[1], value_states], axis=1)
        past_key_value = (key_states, value_states) if use_cache else None

        # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
        # repeat k/v heads if n_kv_heads < n_heads
        # NOTE(review): the repeat is skipped when use_flash_attention is set on
        # Paddle > 2.6, yet ``self.attn_func`` is bound to the plain
        # ``scaled_dot_product_attention`` above - confirm that path handles
        # un-repeated KV heads.
        paddle_version = float(paddle.__version__[:3])
        if not self.config.use_flash_attention or (
            (paddle_version != 0.0) and (paddle_version <= 2.6)
        ):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

        outputs = self.attn_func(
            query_states,
            self.config,
            key_states,
            value_states,
            attention_mask,
            output_attentions,
            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
            training=self.training,
            sequence_parallel=self.sequence_parallel,
        )
        if output_attentions:
            attn_output, attn_weights = outputs
        else:
            attn_output = outputs
            attn_weights = None

        # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
        # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
        attn_output = self.o_proj(attn_output)

        outputs = (attn_output,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (past_key_value,)

        # Unwrap the single-element tuple so callers get the tensor directly.
        if isinstance(outputs, tuple) and len(outputs) == 1:
            outputs = outputs[0]

        return outputs
class Qwen2DecoderLayer(nn.Layer):
    """Single Qwen2 decoder block: RMSNorm -> self-attention -> residual add,
    then RMSNorm -> MLP -> residual add."""

    def __init__(
        self,
        config: Qwen2Config,
        layerwise_recompute: bool = False,
        skip_recompute_ops=None,
    ):
        """
        Args:
            config: model configuration; `hidden_size` and `recompute_granularity` are read here.
            layerwise_recompute: stored on the layer and forwarded to `Qwen2Attention`.
            skip_recompute_ops: forwarded to the attention and MLP sub-layers; `None` becomes
                a fresh empty dict.
        """
        super().__init__()
        # Build a new dict per layer instead of sharing one default instance.
        skip_recompute_ops = {} if skip_recompute_ops is None else skip_recompute_ops

        self.config = config
        self.skip_recompute_ops = skip_recompute_ops
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen2Attention(
            config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops
        )
        self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops)
        self.input_layernorm = Qwen2RMSNorm(config)
        self.post_attention_layernorm = Qwen2RMSNorm(config)

        # Note that we will actually perform a recompute only if both enable_recompute and
        # layerwise_recompute are set to True.
        # Enable_recompute defaults to False and is controlled by Trainer.
        self.enable_recompute = False
        self.layerwise_recompute = layerwise_recompute
        self.recompute_granularity = config.recompute_granularity

    def forward(
        self,
        hidden_states: paddle.Tensor,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        use_cache: Optional[bool] = False,
        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
        **kwargs,
    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_ids (`paddle.Tensor`, *optional*): passed through to the attention module.
            attention_mask (`paddle.Tensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            attn_mask_startend_row_indices (`paddle.Tensor`, *optional*): passed through to the
                attention module.
            **kwargs: accepted and ignored.

        Returns:
            The output hidden states alone when neither `output_attentions` nor `use_cache`
            is set; otherwise a tuple `(hidden_states, [attn_weights], [present_key_value])`
            containing only the requested extras, in that order.
        """
        # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
        attn_input = self.input_layernorm(hidden_states)

        # Self Attention
        attn_result = self.self_attn(
            attn_input,
            position_ids,
            past_key_value,
            attention_mask,
            output_attentions,
            use_cache,
            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
        )
        # The attention module returns a bare tensor when it has nothing extra to report.
        attn_hidden = attn_result[0] if type(attn_result) is tuple else attn_result
        if output_attentions:
            self_attn_weights = attn_result[1]
        if use_cache:
            # The cache entry sits after the attention weights when those are present.
            present_key_value = attn_result[2 if output_attentions else 1]
        after_attn = hidden_states + attn_hidden

        # Fully Connected
        mlp_hidden = self.mlp(self.post_attention_layernorm(after_attn))
        block_output = after_attn + mlp_hidden

        result = (block_output,)
        if output_attentions:
            result += (self_attn_weights,)
        if use_cache:
            result += (present_key_value,)
        if len(result) == 1:
            return result[0]
        return result
class Qwen2PretrainedModel(PretrainedModel):
    """Common base for Qwen2 models: binds `Qwen2Config` and provides the
    weight fuse/split mapping used when converting checkpoints."""

    config_class = Qwen2Config
    base_model_prefix = "qwen2"
    _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]

    @classmethod
    def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
        """
        Build a mapping from per-layer parameter-name tuples to the callable that
        fuses or splits them.

        In every key tuple the last name is the fused parameter and the preceding
        names are the separate parameters.

        Args:
            config: supplies `num_attention_heads`, `num_hidden_layers` and the optional
                `num_key_value_heads`, `fuse_attention_qkv`, `fuse_attention_ffn`.
            is_fuse: when true, emit fuse actions for the parts the config marks as fused;
                when false, emit split actions for the parts the config marks as not fused.

        Returns:
            dict mapping tuples of parameter names to callables.
        """
        # return parameter fuse utils
        from ...common.vlm.conversion_utils import split_or_fuse_func

        fn = split_or_fuse_func(is_fuse=is_fuse)

        # Templates are written for layer 0 and rewritten for every layer below.
        qkv_templates = [
            (
                "layers.0.self_attn.q_proj.weight",
                "layers.0.self_attn.k_proj.weight",
                "layers.0.self_attn.v_proj.weight",
                "layers.0.self_attn.qkv_proj.weight",
            ),
            (
                "layers.0.self_attn.q_proj.bias",
                "layers.0.self_attn.k_proj.bias",
                "layers.0.self_attn.v_proj.bias",
                "layers.0.self_attn.qkv_proj.bias",
            ),
        ]
        gate_up_template = (
            "layers.0.mlp.gate_proj.weight",
            "layers.0.mlp.up_proj.weight",
            "layers.0.mlp.gate_up_fused_proj.weight",
        )

        num_heads = config.num_attention_heads
        num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
        fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
        fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)

        def keys_for_layer(template, layer_idx):
            # Rewrite the layer-0 template names for the given layer.
            return tuple(
                name.replace("layers.0.", f"layers.{layer_idx}.") for name in template
            )

        if is_fuse:
            handle_qkv, handle_ffn = fuse_attention_qkv, fuse_attention_ffn
            qkv_kwargs = dict(
                is_qkv=True,
                num_heads=num_heads,
                num_key_value_heads=num_key_value_heads,
            )
            ffn_kwargs = None  # fusing gate/up uses the bare function
        else:
            handle_qkv, handle_ffn = not fuse_attention_qkv, not fuse_attention_ffn
            qkv_kwargs = dict(
                split_nums=3,
                is_qkv=True,
                num_heads=num_heads,
                num_key_value_heads=num_key_value_heads,
            )
            ffn_kwargs = dict(split_nums=2)

        final_actions = {}
        if handle_qkv:
            for layer_idx in range(config.num_hidden_layers):
                for template in qkv_templates:
                    final_actions[keys_for_layer(template, layer_idx)] = partial(
                        fn, **qkv_kwargs
                    )
        if handle_ffn:
            for layer_idx in range(config.num_hidden_layers):
                action = fn if ffn_kwargs is None else partial(fn, **ffn_kwargs)
                final_actions[keys_for_layer(gate_up_template, layer_idx)] = action
        return final_actions
class Qwen2Model(Qwen2PretrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.sequence_parallel = config.sequence_parallel
        self.recompute_granularity = config.recompute_granularity
        self.no_recompute_layers = (
            config.no_recompute_layers if config.no_recompute_layers is not None else []
        )

        # Recompute defaults to False and is controlled by Trainer
        self.enable_recompute = False

        # Use the vocab-parallel embedding only when the vocabulary divides evenly
        # across the tensor-parallel ranks.
        if (
            config.tensor_parallel_degree > 1
            and config.vocab_size % config.tensor_parallel_degree == 0
        ):
            self.embed_tokens = mpu.VocabParallelEmbedding(
                self.vocab_size,
                self.hidden_size,
                weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
            )
        else:
            self.embed_tokens = nn.Embedding(
                self.vocab_size,
                self.hidden_size,
            )

        self.layers = nn.LayerList(
            [
                Qwen2DecoderLayer(
                    config=config,
                    layerwise_recompute=layer_idx not in self.no_recompute_layers,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen2RMSNorm(config)

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer with `value`."""
        self.embed_tokens = value

    @staticmethod
    def _prepare_decoder_attention_mask(
        attention_mask, input_shape, past_key_values_length, dtype
    ):
        """
        Turn a user-supplied mask (or none) into an additive float attention mask.

        Args:
            attention_mask: `None`, or a 2-D, 3-D or higher-rank mask tensor.
            input_shape: `(batch_size, seq_length)` of the current input.
            past_key_values_length: number of cached positions preceding the input.
            dtype: dtype of the returned mask (not applied on the xpu path, which
                always produces float32 values).

        Returns:
            A mask holding 0.0 where attention is allowed and a large negative value
            where it is not.
        """
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            if len(attention_mask.shape) == 2:
                expanded_attn_mask = _expand_2d_mask(
                    attention_mask, dtype, tgt_length=input_shape[-1]
                )
                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
                if input_shape[-1] > 1:
                    combined_attention_mask = _make_causal_mask(
                        input_shape,
                        past_key_values_length=past_key_values_length,
                    )
                    expanded_attn_mask = expanded_attn_mask & combined_attention_mask
            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
            elif len(attention_mask.shape) == 3:
                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
            # if attention_mask is already 4-D, do nothing
            else:
                expanded_attn_mask = attention_mask
        else:
            expanded_attn_mask = _make_causal_mask(
                input_shape,
                past_key_values_length=past_key_values_length,
            )

        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
        if get_device_type() == "xpu":
            x = paddle.to_tensor(0.0, dtype="float32")
            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
        else:
            expanded_attn_mask = paddle.where(
                expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min
            ).astype(dtype)
        return expanded_attn_mask

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        attn_mask_startend_row_indices=None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Embed the input, run every decoder layer and apply the final norm.

        Args:
            input_ids: token ids of shape `(batch_size, seq_length)`; mutually exclusive
                with `inputs_embeds`.
            position_ids: positions used by the attention layers. When omitted they
                default to `cache_length .. cache_length + seq_length - 1` for every row.
            attention_mask: optional mask; defaults to all-ones over cached plus new positions.
            inputs_embeds: pre-computed embeddings of shape `(batch_size, seq_length, hidden)`.
            use_cache: whether to return per-layer key/value caches.
            past_key_values: per-layer caches from a previous call; entries are cleared
                from the local list as soon as each layer has consumed them.
            output_attentions: whether to collect attention weights of every layer.
            output_hidden_states: whether to collect the hidden states fed to every
                layer plus the final normed output.
            return_dict: return a `BaseModelOutputWithPast` instead of a tuple.
            attn_mask_startend_row_indices: passed through to every decoder layer.

        Returns:
            `BaseModelOutputWithPast`, or a tuple of its non-`None` fields in the order
            `(hidden_states, cache, all_hidden_states, all_attentions)`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))
        # NOTE: to make cache can be clear in-time
        past_key_values = list(past_key_values)

        seq_length_with_past = seq_length
        cache_length = 0
        if past_key_values[0] is not None:
            cache_length = past_key_values[0][0].shape[1]
            seq_length_with_past += cache_length

        if inputs_embeds is None:
            # [bs, seq_len, dim]
            inputs_embeds = self.embed_tokens(input_ids)

        if self.sequence_parallel:
            # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
            bs, seq_len, hidden_size = inputs_embeds.shape
            inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size])
            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
            inputs_embeds = ScatterOp.apply(inputs_embeds)

        # [bs, seq_len]
        attention_mask = (
            paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
            if attention_mask is None
            else attention_mask
        )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
        )  # [bs, 1, seq_len, seq_len]
        if self.config.use_flash_attention:
            attention_mask = None if is_casual_mask(attention_mask) else attention_mask

        if position_ids is None:
            # Start after the cached prefix: new tokens must not reuse positions
            # 0..seq_length-1 when earlier positions already live in the KV cache.
            # With no cache this is the same as arange(seq_length).
            position_ids = paddle.arange(
                cache_length, seq_length_with_past, dtype="int64"
            ).expand((batch_size, seq_length))

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # past_key_values is always a list at this point (see above).
            past_key_value = past_key_values[idx]

            has_gradient = not hidden_states.stop_gradient
            if (
                self.enable_recompute
                and idx not in self.no_recompute_layers
                and has_gradient
                and self.recompute_granularity == "full"
            ):
                # NOTE(review): recompute_training_full is not defined in the visible part
                # of this class or of Qwen2PretrainedModel - confirm a base class provides it.
                layer_outputs = self.recompute_training_full(
                    decoder_layer,
                    hidden_states,
                    position_ids,
                    attention_mask,
                    output_attentions,
                    past_key_value,
                    use_cache,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_ids,
                    attention_mask,
                    output_attentions,
                    past_key_value,
                    use_cache,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                )

            # NOTE: clear outdate cache after it has been used for memory saving
            past_key_value = past_key_values[idx] = None

            if type(layer_outputs) is tuple:
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class Qwen2PretrainingCriterion(nn.Layer):
    """
    Criterion for Qwen2.
    It calculates the final loss: the per-token cross entropy averaged over
    the tokens whose loss is strictly positive.
    """

    def __init__(self, config: Qwen2Config):
        """
        Args:
            config: supplies `tensor_parallel_degree`, `tensor_parallel_output`,
                `vocab_size` and the optional `ignore_index` (default -100).
        """
        super().__init__()
        self.config = config
        self.ignore_index = getattr(config, "ignore_index", -100)
        self.enable_parallel_cross_entropy = (
            config.tensor_parallel_degree > 1 and config.tensor_parallel_output
        )

        # and lm_head is distributed
        if self.enable_parallel_cross_entropy:
            self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
        else:
            self.loss_func = paddle.nn.CrossEntropyLoss(
                reduction="none", ignore_index=self.ignore_index
            )

    def forward(self, prediction_scores, masked_lm_labels):
        """
        Compute the scalar loss.

        Args:
            prediction_scores: logits; cast to float32 before the loss is evaluated.
            masked_lm_labels: target ids; a trailing axis is added via `unsqueeze(2)`.

        Returns:
            Sum of the positive per-token losses divided by their count, or the
            plain (zero-weighted) sum when no token has a positive loss.
        """
        if (
            self.enable_parallel_cross_entropy
            and prediction_scores.shape[-1] == self.config.vocab_size
        ):
            # Logits carry the full vocabulary, so the parallel loss does not apply;
            # switch this instance to the regular cross entropy.
            logging.warning(
                f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}"
            )
            self.loss_func = paddle.nn.CrossEntropyLoss(
                reduction="none", ignore_index=self.ignore_index
            )

        with paddle.amp.auto_cast(False):
            logits_fp32 = prediction_scores.astype("float32")
            masked_lm_loss = self.loss_func(logits_fp32, masked_lm_labels.unsqueeze(2))

            # skip ignore_index which loss == 0
            binary_sequence = paddle.where(
                masked_lm_loss > 0,
                paddle.ones_like(masked_lm_loss),
                paddle.zeros_like(masked_lm_loss),
            )
            count = paddle.sum(binary_sequence)
            loss_sum = paddle.sum(masked_lm_loss * binary_sequence)
            # Avoid dividing by zero when every token was ignored.
            loss = loss_sum if count == 0 else loss_sum / count

        return loss
class Qwen2LMHead(nn.Layer):
    """Projection from hidden states to vocabulary logits via `parallel_matmul`."""

    def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
        """
        Args:
            config: supplies `vocab_size`, `hidden_size` and `tensor_parallel_degree`.
            embedding_weights: existing weight to reuse; only honoured when
                `transpose_y` is true.
            transpose_y: whether the weight is stored as `[vocab, hidden]` and
                transposed during the matmul.
        """
        super().__init__()
        self.config = config

        # The vocabulary is sharded only if it divides evenly across the ranks.
        shard_vocab = (
            config.tensor_parallel_degree > 1
            and config.vocab_size % config.tensor_parallel_degree == 0
        )
        vocab_size = (
            config.vocab_size // config.tensor_parallel_degree
            if shard_vocab
            else config.vocab_size
        )

        self.transpose_y = transpose_y
        if transpose_y and embedding_weights is not None:
            self.weight = embedding_weights
        else:
            weight_shape = (
                [vocab_size, config.hidden_size]
                if transpose_y
                else [config.hidden_size, vocab_size]
            )
            self.weight = self.create_parameter(
                shape=weight_shape,
                dtype=paddle.get_default_dtype(),
            )

        # Must set distributed attr for Tensor Parallel !
        self.weight.is_distributed = vocab_size != config.vocab_size
        if self.weight.is_distributed:
            # for tie_word_embeddings
            self.weight.split_axis = 0 if self.transpose_y else 1

    def forward(self, hidden_states, tensor_parallel_output=None):
        """
        Compute logits for `hidden_states`.

        Args:
            hidden_states: decoder output; under sequence parallelism it is gathered and
                reshaped to `[-1, config.seq_length, config.hidden_size]` first.
            tensor_parallel_output: forwarded to `parallel_matmul`; falls back to
                `config.tensor_parallel_output` when `None`.

        Returns:
            The result of `parallel_matmul` on the hidden states and the head weight.
        """
        cfg = self.config
        if cfg.sequence_parallel:
            gathered = GatherOp.apply(hidden_states)
            hidden_states = paddle.reshape_(
                gathered, [-1, cfg.seq_length, cfg.hidden_size]
            )

        if tensor_parallel_output is None:
            tensor_parallel_output = cfg.tensor_parallel_output

        return parallel_matmul(
            hidden_states,
            self.weight,
            transpose_y=self.transpose_y,
            tensor_parallel_output=tensor_parallel_output,
        )
class Qwen2ForCausalLM(Qwen2PretrainedModel):
    """Qwen2 decoder (`Qwen2Model`) followed by a `Qwen2LMHead` producing vocabulary logits."""

    enable_to_static_method = True
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.qwen2 = Qwen2Model(config)
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output head.
            self.lm_head = Qwen2LMHead(
                config,
                embedding_weights=self.qwen2.embed_tokens.weight,
                transpose_y=True,
            )
            self.tie_weights()
        else:
            self.lm_head = Qwen2LMHead(config)
        self.criterion = Qwen2PretrainingCriterion(config)
        self.vocab_size = config.vocab_size

    def get_input_embeddings(self):
        """Return the decoder's token embedding layer."""
        return self.qwen2.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding layer."""
        self.qwen2.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the underlying decoder model."""
        self.qwen2 = decoder

    def get_decoder(self):
        """Return the underlying decoder model."""
        return self.qwen2

    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache=False,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        Args:
            input_ids: token ids of shape `(batch_size, seq_length)`.
            use_cache: forwarded unchanged.
            past_key_values: cache from earlier steps; when truthy only the last token
                and its position are fed to the model.
            attention_mask: forwarded unchanged.
            inputs_embeds: used instead of `input_ids` only while `past_key_values` is `None`.
            **kwargs: `position_ids` is read from here; a missing or `None` value falls
                back to `0 .. seq_length - 1` for every row.

        Returns:
            dict with `input_ids` or `inputs_embeds`, plus `position_ids`,
            `past_key_values`, `use_cache` and `attention_mask`.
        """
        batch_size, seq_length = input_ids.shape

        position_ids = kwargs.get("position_ids")
        if position_ids is None:
            # An explicit None is treated like a missing key; otherwise the slice
            # below would fail once a cache is present.
            position_ids = paddle.arange(seq_length).expand((batch_size, seq_length))

        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(axis=-1)
            position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _get_model_inputs_spec(self, dtype: str):
        """Return the `InputSpec` of each model input; `dtype` is not used here."""
        return {
            "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
            "attention_mask": paddle.static.InputSpec(
                shape=[None, None], dtype="int64"
            ),
            "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
        }

    @staticmethod
    def update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        """
        Carry cache, position ids and attention mask over to the next generation step.

        Args:
            outputs: model output of the step just run (tuple or `CausalLMOutputWithPast`).
            model_kwargs: generation kwargs; updated in place and returned.
            is_encoder_decoder: when true the attention mask is left untouched.

        Returns:
            The updated `model_kwargs`.
        """
        # update cache
        if (
            isinstance(outputs, tuple)
            and len(outputs) > 1
            and not isinstance(outputs[1], paddle.Tensor)
        ):
            model_kwargs["past_key_values"] = outputs[1]

        if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update position_ids
        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
            position_ids = model_kwargs["position_ids"]
            # Append the next position (last + 1) for the token about to be generated.
            model_kwargs["position_ids"] = paddle.concat(
                [position_ids, position_ids[..., -1:] + 1], axis=-1
            )

        if not is_encoder_decoder and "attention_mask" in model_kwargs:
            # TODO: support attention mask for other models
            attention_mask = model_kwargs["attention_mask"]
            if len(attention_mask.shape) == 2:
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [attention_mask.shape[0], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )
            elif len(attention_mask.shape) == 4:
                # Extend the key axis by one and keep only the last query row.
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [*attention_mask.shape[:3], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )[:, :, -1:, :]

        return model_kwargs

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        labels: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        attn_mask_startend_row_indices=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the decoder and project its output to logits.

        Args:
            input_ids, position_ids, attention_mask, inputs_embeds, use_cache,
            past_key_values, output_attentions, output_hidden_states, return_dict,
            attn_mask_startend_row_indices:
                forwarded to the underlying `Qwen2Model`. If both
                `attn_mask_startend_row_indices` and `attention_mask` are given, a warning
                is logged and `attention_mask` is dropped.
            labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                NOTE(review): accepted for interface compatibility but not used in this
                method; no loss is computed here and the returned `loss` is always `None`.

        Returns:
            `CausalLMOutputWithPast` when `return_dict` is true, otherwise the tuple
            `(logits,) + decoder_outputs[1:]`.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if attn_mask_startend_row_indices is not None and attention_mask is not None:
            logging.warning(
                "You have provided both attn_mask_startend_row_indices and attention_mask. "
                "The attn_mask_startend_row_indices will be used."
            )
            attention_mask = None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.qwen2(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
        )

        hidden_states = outputs[0]

        # tensor_parallel_output is together with ParallelCrossEntropy
        # NOTE(review): the flag below does not depend on `labels`, so logits stay
        # vocabulary-sharded under tensor parallelism even when no labels are given.
        tensor_parallel_output = (
            self.config.tensor_parallel_output
            and self.config.tensor_parallel_degree > 1
        )
        logits = self.lm_head(
            hidden_states, tensor_parallel_output=tensor_parallel_output
        )

        # No loss is computed in this method.
        loss = None

        if not return_dict:
            return (logits,) + outputs[1:]

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|