|
|
@@ -0,0 +1,1706 @@
|
|
|
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+import math
|
|
|
+from functools import partial
|
|
|
+from typing import List, Optional, Tuple, Union
|
|
|
+
|
|
|
+import paddle
|
|
|
+import paddle.distributed.fleet.meta_parallel as mpu
|
|
|
+import paddle.nn as nn
|
|
|
+import paddle.nn.functional as F
|
|
|
+from paddle import Tensor
|
|
|
+from paddle.distributed import fleet
|
|
|
+from paddle.distributed.fleet.utils import sequence_parallel_utils
|
|
|
+
|
|
|
+from .....utils import logging
|
|
|
+from .....utils.env import get_device_type
|
|
|
+from ...common.vlm import fusion_ops
|
|
|
+from ...common.vlm.activations import ACT2FN
|
|
|
+from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
|
|
|
+from ...common.vlm.transformers.model_outputs import (
|
|
|
+ BaseModelOutputWithPast,
|
|
|
+ CausalLMOutputWithPast,
|
|
|
+)
|
|
|
+
|
|
|
+try:
|
|
|
+ from paddle.incubate.nn.functional import fused_rotary_position_embedding
|
|
|
+except ImportError:
|
|
|
+ fused_rotary_position_embedding = None
|
|
|
+
|
|
|
+try:
|
|
|
+ from paddle.distributed.fleet.utils.sequence_parallel_utils import (
|
|
|
+ GatherOp,
|
|
|
+ ScatterOp,
|
|
|
+ mark_as_sequence_parallel_parameter,
|
|
|
+ )
|
|
|
+except ImportError:
|
|
|
+ pass
|
|
|
+
|
|
|
+try:
|
|
|
+ from paddle.nn.functional.flash_attention import flash_attention
|
|
|
+except ImportError:
|
|
|
+ flash_attention = None
|
|
|
+
|
|
|
+
|
|
|
+Linear = nn.Linear
|
|
|
+ColumnParallelLinear = mpu.ColumnParallelLinear
|
|
|
+RowParallelLinear = mpu.RowParallelLinear
|
|
|
+ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
|
|
|
+RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2Config(PretrainedConfig):
|
|
|
+ r"""
|
|
|
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
|
|
|
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
|
+ with the defaults will yield a similar configuration to that of
|
|
|
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
|
|
|
+
|
|
|
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
|
+ documentation from [`PretrainedConfig`] for more information.
|
|
|
+
|
|
|
+
|
|
|
+ Args:
|
|
|
+ vocab_size (`int`, *optional*, defaults to 151936):
|
|
|
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
|
|
|
+ `inputs_ids` passed when calling [`Qwen2Model`]
|
|
|
+ hidden_size (`int`, *optional*, defaults to 4096):
|
|
|
+ Dimension of the hidden representations.
|
|
|
+ intermediate_size (`int`, *optional*, defaults to 22016):
|
|
|
+ Dimension of the MLP representations.
|
|
|
+ num_hidden_layers (`int`, *optional*, defaults to 32):
|
|
|
+ Number of hidden layers in the Transformer encoder.
|
|
|
+ num_attention_heads (`int`, *optional*, defaults to 32):
|
|
|
+ Number of attention heads for each attention layer in the Transformer encoder.
|
|
|
+ num_key_value_heads (`int`, *optional*, defaults to 32):
|
|
|
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
|
|
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
|
|
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When
|
|
|
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
|
|
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to `32`.
|
|
|
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
|
|
+ The non-linear activation function (function or string) in the decoder.
|
|
|
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
|
|
|
+ The maximum sequence length that this model might ever be used with.
|
|
|
+ initializer_range (`float`, *optional*, defaults to 0.02):
|
|
|
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
|
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
|
|
+ The epsilon used by the rms normalization layers.
|
|
|
+ use_cache (`bool`, *optional*, defaults to `True`):
|
|
|
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
|
|
|
+ relevant if `config.is_decoder=True`.
|
|
|
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
|
+ Whether the model's input and output word embeddings should be tied.
|
|
|
+ rope_theta (`float`, *optional*, defaults to 10000.0):
|
|
|
+ The base period of the RoPE embeddings.
|
|
|
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
|
|
|
+ Whether to use sliding window attention.
|
|
|
+ sliding_window (`int`, *optional*, defaults to 4096):
|
|
|
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
|
|
+ max_window_layers (`int`, *optional*, defaults to 28):
|
|
|
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
|
|
+ attention_dropout (`float`, *optional*, defaults to 0.0):
|
|
|
+ The dropout ratio for the attention probabilities.
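+
+    Example (a minimal sketch; the sizes below are illustrative and any flags not
+    listed fall back to the defaults inherited from `PretrainedConfig`):
+
+    ```python
+    >>> # Build a small, hypothetical configuration.
+    >>> configuration = Qwen2Config(hidden_size=1024, num_hidden_layers=4,
+    ...                             num_attention_heads=16, num_key_value_heads=8)
+    >>> configuration.num_key_value_heads
+    8
+    >>> # A model can then be instantiated as: model = Qwen2Model(configuration)
+    ```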
|
|
|
+ """
|
|
|
+
|
|
|
+ model_type = "qwen2"
|
|
|
+ keys_to_ignore_at_inference = ["past_key_values"]
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ vocab_size=151936,
|
|
|
+ hidden_size=4096,
|
|
|
+ intermediate_size=22016,
|
|
|
+ num_hidden_layers=32,
|
|
|
+ num_attention_heads=32,
|
|
|
+ num_key_value_heads=32,
|
|
|
+ hidden_act="silu",
|
|
|
+ max_position_embeddings=32768,
|
|
|
+ seq_length=32768,
|
|
|
+ initializer_range=0.02,
|
|
|
+ rms_norm_eps=1e-6,
|
|
|
+ use_cache=True,
|
|
|
+ tie_word_embeddings=False,
|
|
|
+ rope_theta=10000.0,
|
|
|
+ pad_token_id=0,
|
|
|
+ bos_token_id=151643,
|
|
|
+ eos_token_id=151643,
|
|
|
+ use_sliding_window=False,
|
|
|
+ sliding_window=4096,
|
|
|
+ max_window_layers=28,
|
|
|
+ attention_dropout=0.0,
|
|
|
+ rope_scaling_factor=1.0,
|
|
|
+ rope_scaling_type=None,
|
|
|
+ dpo_config=None,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
+ self.vocab_size = vocab_size
|
|
|
+ self.max_position_embeddings = max_position_embeddings
|
|
|
+ self.seq_length = seq_length
|
|
|
+ self.hidden_size = hidden_size
|
|
|
+ self.intermediate_size = intermediate_size
|
|
|
+ self.num_hidden_layers = num_hidden_layers
|
|
|
+ self.num_attention_heads = num_attention_heads
|
|
|
+ self.use_sliding_window = use_sliding_window
|
|
|
+ self.sliding_window = sliding_window
|
|
|
+ self.max_window_layers = max_window_layers
|
|
|
+
|
|
|
+ # for backward compatibility
|
|
|
+ if num_key_value_heads is None:
|
|
|
+ num_key_value_heads = num_attention_heads
|
|
|
+
|
|
|
+ self.num_key_value_heads = num_key_value_heads
|
|
|
+ self.hidden_act = hidden_act
|
|
|
+ self.initializer_range = initializer_range
|
|
|
+ self.rms_norm_eps = rms_norm_eps
|
|
|
+ self.use_cache = use_cache
|
|
|
+ self.rope_theta = rope_theta
|
|
|
+ self.attention_dropout = attention_dropout
|
|
|
+
|
|
|
|
|
|
+ self.rope_scaling_factor = rope_scaling_factor
|
|
|
+ self.rope_scaling_type = rope_scaling_type
|
|
|
+
|
|
|
+ self.pad_token_id = pad_token_id
|
|
|
+ self.bos_token_id = bos_token_id
|
|
|
+ self.eos_token_id = eos_token_id
|
|
|
+ self.dpo_config = dpo_config
|
|
|
+
|
|
|
+ super().__init__(
|
|
|
+ pad_token_id=pad_token_id,
|
|
|
+ bos_token_id=bos_token_id,
|
|
|
+ eos_token_id=eos_token_id,
|
|
|
+ tie_word_embeddings=tie_word_embeddings,
|
|
|
+ **kwargs,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def get_triangle_upper_mask(x, mask=None):
|
|
|
+ if mask is not None:
|
|
|
+ return mask
|
|
|
+ # [bsz, n_head, q_len, kv_seq_len]
|
|
|
+ shape = x.shape
|
|
|
+ # [bsz, 1, q_len, kv_seq_len]
|
|
|
+ shape[1] = 1
|
|
|
+ mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
|
|
|
+ mask = paddle.triu(mask, diagonal=1)
|
|
|
+ mask.stop_gradient = True
|
|
|
+ return mask
|
|
|
+
|
|
|
+
|
|
|
+def parallel_matmul(
|
|
|
+ x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
|
|
|
+):
|
|
|
+ is_fleet_init = True
|
|
|
+ tensor_parallel_degree = 1
|
|
|
+ try:
|
|
|
+ hcg = fleet.get_hybrid_communicate_group()
|
|
|
+ model_parallel_group = hcg.get_model_parallel_group()
|
|
|
+ tensor_parallel_degree = hcg.get_model_parallel_world_size()
|
|
|
+ except:
|
|
|
+ is_fleet_init = False
|
|
|
+
|
|
|
+ if paddle.in_dynamic_mode():
|
|
|
+ y_is_distributed = y.is_distributed
|
|
|
+ else:
|
|
|
+ y_is_distributed = tensor_parallel_degree > 1
|
|
|
+
|
|
|
+ if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
|
|
|
+ # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
|
|
|
+ input_parallel = paddle.distributed.collective._c_identity(
|
|
|
+ x, group=model_parallel_group
|
|
|
+ )
|
|
|
+ logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
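+        # Each rank holds a shard of y along the vocab axis, so `logits` is also
+        # sharded; `_c_concat` below gathers the shards when the caller needs the
+        # full-vocab output.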
|
|
|
+
|
|
|
+ if tensor_parallel_output:
|
|
|
+ return logits
|
|
|
+
|
|
|
+ return paddle.distributed.collective._c_concat(
|
|
|
+ logits, group=model_parallel_group
|
|
|
+ )
|
|
|
+
|
|
|
+ else:
|
|
|
+ logits = paddle.matmul(x, y, transpose_y=transpose_y)
|
|
|
+ return logits
|
|
|
+
|
|
|
+
|
|
|
+def scaled_dot_product_attention(
|
|
|
+ query_states,
|
|
|
+ config,
|
|
|
+ key_states,
|
|
|
+ value_states,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ attn_mask_startend_row_indices=None,
|
|
|
+ training=True,
|
|
|
+ sequence_parallel=False,
|
|
|
+ skip_recompute=False,
|
|
|
+):
|
|
|
+ bsz, q_len, num_heads, head_dim = query_states.shape
|
|
|
+ _, kv_seq_len, _, _ = value_states.shape
|
|
|
+
|
|
|
+ if config.use_flash_attention and flash_attention:
|
|
|
+ # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
|
|
|
+ # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
|
|
|
+
|
|
|
+ return fusion_ops.fusion_flash_attention(
|
|
|
+ query_states,
|
|
|
+ config,
|
|
|
+ key_states,
|
|
|
+ value_states,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ sequence_parallel=sequence_parallel,
|
|
|
+ skip_recompute=skip_recompute,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
|
|
|
+ query_states = paddle.transpose(query_states, [0, 2, 1, 3])
|
|
|
+ # merge with the next transpose
|
|
|
+ key_states = paddle.transpose(key_states, [0, 2, 1, 3])
|
|
|
+ value_states = paddle.transpose(value_states, [0, 2, 1, 3])
|
|
|
+
|
|
|
+        # Pre-divide the query by a factor to avoid NaN/overflow under float16.
|
|
|
+ if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16:
|
|
|
+ pre_divided_factor = 32
|
|
|
+ else:
|
|
|
+ pre_divided_factor = 1
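+        # The factor is multiplied back inside the softmax below, so the attention
+        # probabilities are unchanged; only the intermediate logits are kept in a
+        # numerically safer range for float16.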
|
|
|
+
|
|
|
+ attn_weights = paddle.matmul(
|
|
|
+ query_states / (math.sqrt(head_dim) * pre_divided_factor),
|
|
|
+ key_states.transpose([0, 1, 3, 2]),
|
|
|
+ )
|
|
|
+
|
|
|
+ if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
|
|
|
+ raise ValueError(
|
|
|
+ f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
|
|
|
+ f" {attn_weights.shape}"
|
|
|
+ )
|
|
|
+
|
|
|
+ if attention_mask is None:
|
|
|
+ attention_mask = get_triangle_upper_mask(attn_weights)
|
|
|
+
|
|
|
+ attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
|
|
|
+ if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
|
|
|
+ raise ValueError(
|
|
|
+ f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
|
|
|
+ )
|
|
|
+
|
|
|
+ attn_weights = attn_weights + attention_mask
|
|
|
+
|
|
|
+ if not paddle.in_dynamic_mode():
|
|
|
+ attn_weights = F.softmax(
|
|
|
+ attn_weights * pre_divided_factor, axis=-1, dtype="float32"
|
|
|
+ ).astype(query_states.dtype)
|
|
|
+ else:
|
|
|
+ with paddle.amp.auto_cast(False):
|
|
|
+ attn_weights = F.softmax(
|
|
|
+ attn_weights.astype("float32") * pre_divided_factor,
|
|
|
+ axis=-1,
|
|
|
+ dtype="float32",
|
|
|
+ ).astype(query_states.dtype)
|
|
|
+
|
|
|
+ attn_weights = F.dropout(
|
|
|
+ attn_weights, p=config.attention_dropout, training=training
|
|
|
+ )
|
|
|
+
|
|
|
+ attn_output = paddle.matmul(attn_weights, value_states)
|
|
|
+ attn_output = attn_output.transpose([0, 2, 1, 3])
|
|
|
+
|
|
|
+ if sequence_parallel:
|
|
|
+ attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
|
|
|
+ else:
|
|
|
+ attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
|
|
|
+ return (attn_output, attn_weights) if output_attentions else attn_output
|
|
|
+
|
|
|
+
|
|
|
+def is_casual_mask(attention_mask):
|
|
|
+ """
|
|
|
+    Return True if the upper triangular part of attention_mask equals the mask itself, i.e. the mask is causal.
|
|
|
+ """
|
|
|
+ return (paddle.triu(attention_mask) == attention_mask).all().item()
|
|
|
+
|
|
|
+
|
|
|
+def _make_causal_mask(input_ids_shape, past_key_values_length):
|
|
|
+ """
|
|
|
+ Make causal mask used for self-attention
|
|
|
+ """
|
|
|
+ batch_size, target_length = input_ids_shape # target_length: seq_len
|
|
|
+
|
|
|
+ mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
|
|
|
+
|
|
|
+ if past_key_values_length > 0:
|
|
|
+ # [tgt_len, tgt_len + past_len]
|
|
|
+ mask = paddle.concat(
|
|
|
+ [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
|
|
|
+ axis=-1,
|
|
|
+ )
|
|
|
+
|
|
|
+ # [bs, 1, tgt_len, tgt_len + past_len]
|
|
|
+ return mask[None, None, :, :].expand(
|
|
|
+ [batch_size, 1, target_length, target_length + past_key_values_length]
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def _expand_2d_mask(mask, dtype, tgt_length):
|
|
|
+ """
|
|
|
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
|
|
+ """
|
|
|
+ batch_size, src_length = mask.shape[0], mask.shape[-1]
|
|
|
+ tgt_length = tgt_length if tgt_length is not None else src_length
|
|
|
+
|
|
|
+ mask = mask[:, None, None, :].astype("bool")
|
|
|
+ mask.stop_gradient = True
|
|
|
+ expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
|
|
|
+
|
|
|
+ return expanded_mask
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2RMSNorm(nn.Layer):
|
|
|
+ def __init__(self, config: Qwen2Config):
|
|
|
+ """
|
|
|
+ Qwen2RMSNorm is equivalent to T5LayerNorm
|
|
|
+ """
|
|
|
+ super().__init__()
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.weight = paddle.create_parameter(
|
|
|
+ shape=[self.hidden_size],
|
|
|
+ dtype=paddle.get_default_dtype(),
|
|
|
+ default_initializer=nn.initializer.Constant(1.0),
|
|
|
+ )
|
|
|
+ self.variance_epsilon = config.rms_norm_eps
|
|
|
+ self.config = config
|
|
|
+
|
|
|
+ if config.sequence_parallel:
|
|
|
+ mark_as_sequence_parallel_parameter(self.weight)
|
|
|
+
|
|
|
+ def forward(self, hidden_states):
|
|
|
+ if self.config.use_fused_rms_norm:
|
|
|
+ return fusion_ops.fusion_rms_norm(
|
|
|
+ hidden_states, self.weight, self.variance_epsilon, False
|
|
|
+ )
|
|
|
+
|
|
|
+ if paddle.in_dynamic_mode():
|
|
|
+ with paddle.amp.auto_cast(False):
|
|
|
+ variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
|
|
|
+ hidden_states = (
|
|
|
+ paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
|
|
|
+ hidden_states = (
|
|
|
+ paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
|
|
|
+ )
|
|
|
+
|
|
|
+ if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
|
|
|
+ hidden_states = paddle.cast(hidden_states, self.weight.dtype)
|
|
|
+ return hidden_states * self.weight
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2RotaryEmbedding(nn.Layer):
|
|
|
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
|
|
|
+ super().__init__()
|
|
|
+ self.dim = dim
|
|
|
+ self.max_position_embeddings = max_position_embeddings
|
|
|
+ self.base = base
|
|
|
+ # [dim / 2]
|
|
|
+ self.inv_freq = 1.0 / (
|
|
|
+ self.base
|
|
|
+ ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
|
|
|
+ )
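+        # inv_freq[i] = 1 / base ** (2 * i / dim); lower-indexed feature pairs rotate faster.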
|
|
|
+ self._set_cos_sin_cache(seq_len=max_position_embeddings)
|
|
|
+
|
|
|
+ def _set_cos_sin_cache(self, seq_len):
|
|
|
+ self.max_seq_len_cached = seq_len
|
|
|
+ # [seq_len]
|
|
|
+ t = paddle.arange(seq_len, dtype="float32")
|
|
|
+ # [seq_len, dim/2]
|
|
|
+ freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
|
|
|
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
|
+ # [seq_len, dim]
|
|
|
+ emb = paddle.concat([freqs, freqs], axis=-1)
|
|
|
+ # [1, seqlen, 1, dim]
|
|
|
+ self.cos_cached = emb.cos()[None, :, None, :]
|
|
|
+ self.sin_cached = emb.sin()[None, :, None, :]
|
|
|
+
|
|
|
+ def forward(self, x, seq_len=None):
|
|
|
+ # x: [bs, num_attention_heads, seq_len, head_size]
|
|
|
+ if seq_len > self.max_seq_len_cached:
|
|
|
+ self._set_cos_sin_cache(seq_len)
|
|
|
+ cos = self.cos_cached[:, :seq_len, :, :]
|
|
|
+ sin = self.sin_cached[:, :seq_len, :, :]
|
|
|
+ return (
|
|
|
+ cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
|
|
|
+ sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def rotate_half(x):
|
|
|
+ """Rotates half the hidden dims of the input."""
|
|
|
+ x1 = x[..., : x.shape[-1] // 2]
|
|
|
+ x2 = x[..., x.shape[-1] // 2 :]
|
|
|
+ return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
|
|
|
+
|
|
|
+
|
|
|
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
|
+ if position_ids is None:
|
|
|
+        # Note: only used for Qwen2ForCausalLMPipe model pretraining
|
|
|
+ cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
|
|
|
+ sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
|
|
|
+ else:
|
|
|
+ cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
|
|
|
+ sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
|
|
|
+ cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
|
|
|
+ sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
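+    # Apply the rotary embedding as x * cos + rotate_half(x) * sin, which rotates each
+    # feature pair (x_i, x_{i + d/2}) by its position-dependent angle.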
|
|
|
+ q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
|
+ k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
|
+ return q_embed, k_embed
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2MLP(nn.Layer):
|
|
|
+ def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None):
|
|
|
+ super().__init__()
|
|
|
+ if skip_recompute_ops is None:
|
|
|
+ skip_recompute_ops = {}
|
|
|
+ self.skip_recompute_ops = skip_recompute_ops
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.intermediate_size = config.intermediate_size
|
|
|
+ self.fuse_attention_ffn = config.fuse_attention_ffn
|
|
|
+
|
|
|
+ self.tensor_parallel_degree = config.tensor_parallel_degree
|
|
|
+
|
|
|
+ if config.sequence_parallel:
|
|
|
+ ColumnParallelLinear = ColumnSequenceParallelLinear
|
|
|
+ RowParallelLinear = RowSequenceParallelLinear
|
|
|
+
|
|
|
+ if config.tensor_parallel_degree > 1:
|
|
|
+ if self.fuse_attention_ffn:
|
|
|
+ self.gate_up_fused_proj = ColumnParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.intermediate_size * 2,
|
|
|
+ gather_output=False,
|
|
|
+ has_bias=False,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.gate_proj = ColumnParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.intermediate_size,
|
|
|
+ gather_output=False,
|
|
|
+ has_bias=False,
|
|
|
+ )
|
|
|
+ self.up_proj = ColumnParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.intermediate_size,
|
|
|
+ gather_output=False,
|
|
|
+ has_bias=False,
|
|
|
+ )
|
|
|
+ self.down_proj = RowParallelLinear(
|
|
|
+ self.intermediate_size,
|
|
|
+ self.hidden_size,
|
|
|
+ input_is_parallel=True,
|
|
|
+ has_bias=False,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ if self.fuse_attention_ffn:
|
|
|
+ self.gate_up_fused_proj = Linear(
|
|
|
+ self.hidden_size, self.intermediate_size * 2, bias_attr=False
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.gate_proj = Linear(
|
|
|
+ self.hidden_size, self.intermediate_size, bias_attr=False
|
|
|
+ ) # w1
|
|
|
+ self.up_proj = Linear(
|
|
|
+ self.hidden_size, self.intermediate_size, bias_attr=False
|
|
|
+ ) # w3
|
|
|
+ self.down_proj = Linear(
|
|
|
+ self.intermediate_size, self.hidden_size, bias_attr=False
|
|
|
+ ) # w2
|
|
|
+
|
|
|
+ if config.hidden_act == "silu":
|
|
|
+ self.act_fn = fusion_ops.swiglu
|
|
|
+ self.fuse_swiglu = True
|
|
|
+ else:
|
|
|
+ self.act_fn = ACT2FN[config.hidden_act]
|
|
|
+ self.fuse_swiglu = False
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ if self.fuse_attention_ffn:
|
|
|
+ x = self.gate_up_fused_proj(x)
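+            # The fused swiglu kernel splits the packed [gate, up] tensor internally,
+            # so y stays None; otherwise split it into the two halves here.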
|
|
|
+ if self.fuse_swiglu:
|
|
|
+ y = None
|
|
|
+ else:
|
|
|
+ x, y = x.chunk(2, axis=-1)
|
|
|
+ else:
|
|
|
+ x, y = self.gate_proj(x), self.up_proj(x)
|
|
|
+
|
|
|
+ if self.fuse_swiglu:
|
|
|
+ x = self.act_fn(x, y)
|
|
|
+ else:
|
|
|
+ x = self.act_fn(x) * y
|
|
|
+
|
|
|
+ return self.down_proj(x)
|
|
|
+
|
|
|
+
|
|
|
+def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
|
|
|
+ """
|
|
|
+    This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=2). The hidden states go from
+    (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim).
|
|
|
+ """
|
|
|
+ batch, slen, num_key_value_heads, head_dim = hidden_states.shape
|
|
|
+ if n_rep == 1:
|
|
|
+ return hidden_states
|
|
|
+
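+    # Insert a new axis after the kv-head axis, tile it n_rep times, then fold it back
+    # so every kv head is shared by its whole group of query heads.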
|
|
|
+ hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
|
|
|
+ return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim])
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2Attention(nn.Layer):
|
|
|
+ """
|
|
|
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
|
|
+ and "Generating Long Sequences with Sparse Transformers".
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: Qwen2Config,
|
|
|
+ layerwise_recompute: bool = True,
|
|
|
+ skip_recompute_ops=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ if skip_recompute_ops is None:
|
|
|
+ skip_recompute_ops = {}
|
|
|
+ self.config = config
|
|
|
+ self.skip_recompute_ops = skip_recompute_ops
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.num_heads = config.num_attention_heads
|
|
|
+
|
|
|
+ self.head_dim = self.hidden_size // config.num_attention_heads
|
|
|
+
|
|
|
+ self.num_key_value_heads = config.num_key_value_heads
|
|
|
+        assert config.num_attention_heads % config.num_key_value_heads == 0
|
|
|
+ self.num_key_value_groups = (
|
|
|
+ config.num_attention_heads // config.num_key_value_heads
|
|
|
+ )
|
|
|
+ self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
|
|
|
+ self.max_position_embeddings = config.max_position_embeddings
|
|
|
+ self.rope_theta = config.rope_theta
|
|
|
+ self.is_causal = True
|
|
|
+ self.attention_dropout = config.attention_dropout
|
|
|
+
|
|
|
+ self.seq_length = config.seq_length
|
|
|
+ self.sequence_parallel = config.sequence_parallel
|
|
|
+
|
|
|
+ self.fuse_attention_qkv = config.fuse_attention_qkv
|
|
|
+
|
|
|
+ # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
|
|
|
+ # Enable_recompute defaults to False and is controlled by Trainer
|
|
|
+ self.enable_recompute = False
|
|
|
+ self.layerwise_recompute = layerwise_recompute
|
|
|
+ self.recompute_granularity = config.recompute_granularity
|
|
|
+ if config.tensor_parallel_degree > 1:
|
|
|
+ assert (
|
|
|
+ self.num_heads % config.tensor_parallel_degree == 0
|
|
|
+ ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
|
|
|
+ self.num_heads = self.num_heads // config.tensor_parallel_degree
|
|
|
+
|
|
|
+ assert (
|
|
|
+ self.num_key_value_heads % config.tensor_parallel_degree == 0
|
|
|
+ ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
|
|
|
+ self.num_key_value_heads = (
|
|
|
+ self.num_key_value_heads // config.tensor_parallel_degree
|
|
|
+ )
|
|
|
+
|
|
|
+ self.use_fused_rope = config.use_fused_rope
|
|
|
+ if self.use_fused_rope:
|
|
|
+ if (
|
|
|
+ get_device_type() not in ["gpu", "xpu"]
|
|
|
+ or fused_rotary_position_embedding is None
|
|
|
+ ):
|
|
|
+ logging.warning(
|
|
|
+ "Enable fuse rope in the config, but fuse rope is not available. "
|
|
|
+ "Will disable fuse rope. Try using latest gpu version of Paddle."
|
|
|
+ )
|
|
|
+ self.use_fused_rope = False
|
|
|
+
|
|
|
+ if config.sequence_parallel:
|
|
|
+ ColumnParallelLinear = ColumnSequenceParallelLinear
|
|
|
+ RowParallelLinear = RowSequenceParallelLinear
|
|
|
+
|
|
|
+ if config.tensor_parallel_degree > 1:
|
|
|
+ if self.fuse_attention_qkv:
|
|
|
+ self.qkv_proj = ColumnParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.hidden_size
|
|
|
+ + 2 * self.config.num_key_value_heads * self.head_dim,
|
|
|
+ has_bias=True,
|
|
|
+ gather_output=False,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.q_proj = ColumnParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.hidden_size,
|
|
|
+ has_bias=True,
|
|
|
+ gather_output=False,
|
|
|
+ )
|
|
|
+ self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
|
|
|
+ self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
|
|
|
+ self.o_proj = RowParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.hidden_size,
|
|
|
+ has_bias=False,
|
|
|
+ input_is_parallel=True,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ if self.fuse_attention_qkv:
|
|
|
+ self.qkv_proj = Linear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.hidden_size
|
|
|
+ + 2 * self.config.num_key_value_heads * self.head_dim,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
|
|
|
+ self.k_proj = Linear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.config.num_key_value_heads * self.head_dim,
|
|
|
+ bias_attr=True,
|
|
|
+ )
|
|
|
+ self.v_proj = Linear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.config.num_key_value_heads * self.head_dim,
|
|
|
+ bias_attr=True,
|
|
|
+ )
|
|
|
+ self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
|
|
|
+
|
|
|
+ self.rotary_emb = Qwen2RotaryEmbedding(
|
|
|
+ self.head_dim,
|
|
|
+ max_position_embeddings=self.max_position_embeddings,
|
|
|
+ base=self.rope_theta,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.attn_func = scaled_dot_product_attention
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ hidden_states,
|
|
|
+ position_ids: Optional[Tuple[paddle.Tensor]] = None,
|
|
|
+ past_key_value: Optional[Tuple[paddle.Tensor]] = None,
|
|
|
+ attention_mask: Optional[paddle.Tensor] = None,
|
|
|
+ output_attentions: bool = False,
|
|
|
+ use_cache: bool = False,
|
|
|
+ attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
|
|
|
+ **kwargs,
|
|
|
+ ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
|
|
|
+ """Input shape: Batch x Time x Channel"""
|
|
|
+ # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
|
|
|
+
|
|
|
+ if self.fuse_attention_qkv:
|
|
|
+ mix_layer = self.qkv_proj(hidden_states)
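+            # The fused projection packs, per kv head, one group of query heads plus one
+            # key head and one value head, hence the (num_key_value_groups + 2) * head_dim
+            # trailing dimension used below.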
|
|
|
+ if self.sequence_parallel:
|
|
|
+ target_shape = [
|
|
|
+ -1,
|
|
|
+ self.seq_length,
|
|
|
+ self.num_key_value_heads,
|
|
|
+ (self.num_key_value_groups + 2) * self.head_dim,
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ target_shape = [
|
|
|
+ 0,
|
|
|
+ 0,
|
|
|
+ self.num_key_value_heads,
|
|
|
+ (self.num_key_value_groups + 2) * self.head_dim,
|
|
|
+ ]
|
|
|
+ mix_layer = paddle.reshape_(mix_layer, target_shape)
|
|
|
+ query_states, key_states, value_states = paddle.split(
|
|
|
+ mix_layer,
|
|
|
+ num_or_sections=[
|
|
|
+ self.num_key_value_groups * self.head_dim,
|
|
|
+ self.head_dim,
|
|
|
+ self.head_dim,
|
|
|
+ ],
|
|
|
+ axis=-1,
|
|
|
+ )
|
|
|
+ if self.gqa_or_mqa:
|
|
|
+ query_states = paddle.reshape_(
|
|
|
+ query_states, [0, 0, self.num_heads, self.head_dim]
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ query_states = self.q_proj(hidden_states)
|
|
|
+ key_states = self.k_proj(hidden_states)
|
|
|
+ value_states = self.v_proj(hidden_states)
|
|
|
+
|
|
|
+ if self.sequence_parallel:
|
|
|
+ target_query_shape = [
|
|
|
+ -1,
|
|
|
+ self.seq_length,
|
|
|
+ self.num_heads,
|
|
|
+ self.head_dim,
|
|
|
+ ]
|
|
|
+ target_key_value_shape = [
|
|
|
+ -1,
|
|
|
+ self.seq_length,
|
|
|
+ self.num_key_value_heads,
|
|
|
+ self.head_dim,
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ target_query_shape = [0, 0, self.num_heads, self.head_dim]
|
|
|
+ target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
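+                # In paddle reshape, a 0 keeps the corresponding input dimension
+                # (batch size and sequence length here) unchanged.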
|
|
|
+ query_states = query_states.reshape(shape=target_query_shape)
|
|
|
+ key_states = key_states.reshape(shape=target_key_value_shape)
|
|
|
+ value_states = value_states.reshape(shape=target_key_value_shape)
|
|
|
+
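+        # States are laid out as [bs, seq_len, num_heads, head_dim], so axis -3 is the
+        # sequence length.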
|
|
|
+ kv_seq_len = key_states.shape[-3]
|
|
|
+ if past_key_value is not None:
|
|
|
+ kv_seq_len += past_key_value[0].shape[-3]
|
|
|
+ if self.use_fused_rope:
|
|
|
+ assert past_key_value is None, "fuse rotary not support cache kv for now"
|
|
|
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
+ query_states, key_states, _ = fused_rotary_position_embedding(
|
|
|
+ query_states,
|
|
|
+ key_states,
|
|
|
+ v=None,
|
|
|
+ sin=sin,
|
|
|
+ cos=cos,
|
|
|
+ position_ids=position_ids,
|
|
|
+ use_neox_rotary_style=False,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
+ query_states, key_states = apply_rotary_pos_emb(
|
|
|
+ query_states, key_states, cos, sin, position_ids
|
|
|
+ )
|
|
|
+
|
|
|
+ # [bs, seq_len, num_head, head_dim]
|
|
|
+ if past_key_value is not None:
|
|
|
+ key_states = paddle.concat([past_key_value[0], key_states], axis=1)
|
|
|
+ value_states = paddle.concat([past_key_value[1], value_states], axis=1)
|
|
|
+ past_key_value = (key_states, value_states) if use_cache else None
|
|
|
+
|
|
|
+ # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
|
|
|
+ # repeat k/v heads if n_kv_heads < n_heads
|
|
|
+ paddle_version = float(paddle.__version__[:3])
|
|
|
+ if not self.config.use_flash_attention or (
|
|
|
+ (paddle_version != 0.0) and (paddle_version <= 2.6)
|
|
|
+ ):
|
|
|
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
|
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
|
+
|
|
|
+ outputs = self.attn_func(
|
|
|
+ query_states,
|
|
|
+ self.config,
|
|
|
+ key_states,
|
|
|
+ value_states,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ training=self.training,
|
|
|
+ sequence_parallel=self.sequence_parallel,
|
|
|
+ )
|
|
|
+ if output_attentions:
|
|
|
+ attn_output, attn_weights = outputs
|
|
|
+ else:
|
|
|
+ attn_output = outputs
|
|
|
+
|
|
|
+        # if sequence_parallel is true, the output shape is [q_len / n, bs, num_head * head_dim];
+        # otherwise it is [bs, q_len, num_head * head_dim], where n is the mp parallelism.
|
|
|
+ attn_output = self.o_proj(attn_output)
|
|
|
+
|
|
|
+ if not output_attentions:
|
|
|
+ attn_weights = None
|
|
|
+
|
|
|
+ outputs = (attn_output,)
|
|
|
+
|
|
|
+ if output_attentions:
|
|
|
+ outputs += (attn_weights,)
|
|
|
+
|
|
|
+ if use_cache:
|
|
|
+ outputs += (past_key_value,)
|
|
|
+
|
|
|
+ if type(outputs) is tuple and len(outputs) == 1:
|
|
|
+ outputs = outputs[0]
|
|
|
+
|
|
|
+ return outputs
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2DecoderLayer(nn.Layer):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: Qwen2Config,
|
|
|
+ layerwise_recompute: bool = False,
|
|
|
+ skip_recompute_ops=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ if skip_recompute_ops is None:
|
|
|
+ skip_recompute_ops = {}
|
|
|
+ self.config = config
|
|
|
+ self.skip_recompute_ops = skip_recompute_ops
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.self_attn = Qwen2Attention(
|
|
|
+ config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops
|
|
|
+ )
|
|
|
+
|
|
|
+ self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops)
|
|
|
+ self.input_layernorm = Qwen2RMSNorm(config)
|
|
|
+ self.post_attention_layernorm = Qwen2RMSNorm(config)
|
|
|
+
|
|
|
+ # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
|
|
|
+ # Enable_recompute defaults to False and is controlled by Trainer
|
|
|
+ self.enable_recompute = False
|
|
|
+ self.layerwise_recompute = layerwise_recompute
|
|
|
+ self.recompute_granularity = config.recompute_granularity
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ hidden_states: paddle.Tensor,
|
|
|
+ position_ids: Optional[paddle.Tensor] = None,
|
|
|
+ attention_mask: Optional[paddle.Tensor] = None,
|
|
|
+ output_attentions: Optional[bool] = False,
|
|
|
+ past_key_value: Optional[Tuple[paddle.Tensor]] = None,
|
|
|
+ use_cache: Optional[bool] = False,
|
|
|
+ attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
|
|
|
+ **kwargs,
|
|
|
+ ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
|
+ attention_mask (`paddle.Tensor`, *optional*): attention mask of size
|
|
|
+ `(batch, sequence_length)` where padding elements are indicated by 0.
|
|
|
+ output_attentions (`bool`, *optional*):
|
|
|
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
|
+ returned tensors for more detail.
|
|
|
+ use_cache (`bool`, *optional*):
|
|
|
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
|
+ (see `past_key_values`).
|
|
|
+ past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
|
|
|
+ """
|
|
|
+
|
|
|
+ # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
|
|
|
+ residual = hidden_states
|
|
|
+
|
|
|
+ hidden_states = self.input_layernorm(hidden_states)
|
|
|
+
|
|
|
+ # Self Attention
|
|
|
+ outputs = self.self_attn(
|
|
|
+ hidden_states,
|
|
|
+ position_ids,
|
|
|
+ past_key_value,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ use_cache,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ )
|
|
|
+
|
|
|
+ if type(outputs) is tuple:
|
|
|
+ hidden_states = outputs[0]
|
|
|
+ else:
|
|
|
+ hidden_states = outputs
|
|
|
+
|
|
|
+ if output_attentions:
|
|
|
+ self_attn_weights = outputs[1]
|
|
|
+
|
|
|
+ if use_cache:
|
|
|
+ present_key_value = outputs[2 if output_attentions else 1]
|
|
|
+
|
|
|
+ hidden_states = residual + hidden_states
|
|
|
+
|
|
|
+ # Fully Connected
|
|
|
+ residual = hidden_states
|
|
|
+ hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
+ hidden_states = self.mlp(hidden_states)
|
|
|
+
|
|
|
+ hidden_states = residual + hidden_states
|
|
|
+
|
|
|
+ outputs = (hidden_states,)
|
|
|
+
|
|
|
+ if output_attentions:
|
|
|
+ outputs += (self_attn_weights,)
|
|
|
+
|
|
|
+ if use_cache:
|
|
|
+ outputs += (present_key_value,)
|
|
|
+
|
|
|
+ if type(outputs) is tuple and len(outputs) == 1:
|
|
|
+ outputs = outputs[0]
|
|
|
+
|
|
|
+ return outputs
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2PretrainedModel(PretrainedModel):
|
|
|
+ config_class = Qwen2Config
|
|
|
+ base_model_prefix = "qwen2"
|
|
|
+ _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
|
|
|
+
|
|
|
+ from paddlenlp.transformers.conversion_utils import split_or_merge_func
|
|
|
+
|
|
|
+ fn = split_or_merge_func(
|
|
|
+ is_split=is_split,
|
|
|
+ tensor_parallel_degree=config.tensor_parallel_degree,
|
|
|
+ tensor_parallel_rank=config.tensor_parallel_rank,
|
|
|
+ num_attention_heads=config.num_attention_heads,
|
|
|
+ )
|
|
|
+
|
|
|
+ def get_tensor_parallel_split_mappings(num_layers):
|
|
|
+ final_actions = {}
|
|
|
+
|
|
|
+ base_actions = {
|
|
|
+ # Row Linear
|
|
|
+ "embed_tokens.weight": partial(fn, is_column=False),
|
|
|
+ "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
|
|
|
+ "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
|
|
|
+ }
|
|
|
+
|
|
|
+ if config.tie_word_embeddings:
|
|
|
+ base_actions["lm_head.weight"] = partial(fn, is_column=False)
|
|
|
+ else:
|
|
|
+ base_actions["lm_head.weight"] = partial(fn, is_column=True)
|
|
|
+
|
|
|
+            if config.vocab_size % config.tensor_parallel_degree != 0:
|
|
|
+ base_actions.pop("lm_head.weight")
|
|
|
+ base_actions.pop("embed_tokens.weight")
|
|
|
+ # Column Linear
|
|
|
+ if config.fuse_attention_qkv:
|
|
|
+ base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ base_actions["layers.0.self_attn.q_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.self_attn.q_proj.bias"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ # if we have enough num_key_value_heads to split, then split it.
|
|
|
+ if config.num_key_value_heads % config.tensor_parallel_degree == 0:
|
|
|
+ base_actions["layers.0.self_attn.k_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.self_attn.v_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.self_attn.k_proj.bias"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.self_attn.v_proj.bias"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+
|
|
|
+ if config.fuse_attention_ffn:
|
|
|
+ base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
|
|
|
+ fn, is_column=True, is_naive_2fuse=True
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ base_actions["layers.0.mlp.gate_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+ base_actions["layers.0.mlp.up_proj.weight"] = partial(
|
|
|
+ fn, is_column=True
|
|
|
+ )
|
|
|
+
|
|
|
+ for key, action in base_actions.items():
|
|
|
+ if "layers.0." in key:
|
|
|
+ for i in range(num_layers):
|
|
|
+ final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
|
|
|
+ final_actions[key] = action
|
|
|
+
|
|
|
+ return final_actions
|
|
|
+
|
|
|
+ mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
|
|
|
+
|
|
|
+ return mappings
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
|
|
|
+ # return parameter fuse utils
|
|
|
+ from paddlenlp.transformers.conversion_utils import split_or_fuse_func
|
|
|
+
|
|
|
+ fn = split_or_fuse_func(is_fuse=is_fuse)
|
|
|
+
|
|
|
+ # last key is fused key, other keys are to be fused.
|
|
|
+ fuse_qkv_keys = [
|
|
|
+ (
|
|
|
+ "layers.0.self_attn.q_proj.weight",
|
|
|
+ "layers.0.self_attn.k_proj.weight",
|
|
|
+ "layers.0.self_attn.v_proj.weight",
|
|
|
+ "layers.0.self_attn.qkv_proj.weight",
|
|
|
+ ),
|
|
|
+ (
|
|
|
+ "layers.0.self_attn.q_proj.bias",
|
|
|
+ "layers.0.self_attn.k_proj.bias",
|
|
|
+ "layers.0.self_attn.v_proj.bias",
|
|
|
+ "layers.0.self_attn.qkv_proj.bias",
|
|
|
+ ),
|
|
|
+ ]
|
|
|
+
|
|
|
+ fuse_gate_up_keys = (
|
|
|
+ "layers.0.mlp.gate_proj.weight",
|
|
|
+ "layers.0.mlp.up_proj.weight",
|
|
|
+ "layers.0.mlp.gate_up_fused_proj.weight",
|
|
|
+ )
|
|
|
+ num_heads = config.num_attention_heads
|
|
|
+ num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
|
|
|
+ fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
|
|
|
+ fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
|
|
|
+
|
|
|
+ final_actions = {}
|
|
|
+ if is_fuse:
|
|
|
+ if fuse_attention_qkv:
|
|
|
+ for i in range(config.num_hidden_layers):
|
|
|
+ for fuse_keys in fuse_qkv_keys:
|
|
|
+ keys = tuple(
|
|
|
+ [
|
|
|
+ key.replace("layers.0.", f"layers.{i}.")
|
|
|
+ for key in fuse_keys
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ final_actions[keys] = partial(
|
|
|
+ fn,
|
|
|
+ is_qkv=True,
|
|
|
+ num_heads=num_heads,
|
|
|
+ num_key_value_heads=num_key_value_heads,
|
|
|
+ )
|
|
|
+ if fuse_attention_ffn:
|
|
|
+ for i in range(config.num_hidden_layers):
|
|
|
+ keys = tuple(
|
|
|
+ [
|
|
|
+ key.replace("layers.0.", f"layers.{i}.")
|
|
|
+ for key in fuse_gate_up_keys
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ final_actions[keys] = fn
|
|
|
+ else:
|
|
|
+ if not fuse_attention_qkv:
|
|
|
+ for i in range(config.num_hidden_layers):
|
|
|
+ for fuse_keys in fuse_qkv_keys:
|
|
|
+ keys = tuple(
|
|
|
+ [
|
|
|
+ key.replace("layers.0.", f"layers.{i}.")
|
|
|
+ for key in fuse_keys
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ final_actions[keys] = partial(
|
|
|
+ fn,
|
|
|
+ split_nums=3,
|
|
|
+ is_qkv=True,
|
|
|
+ num_heads=num_heads,
|
|
|
+ num_key_value_heads=num_key_value_heads,
|
|
|
+ )
|
|
|
+ if not fuse_attention_ffn:
|
|
|
+ for i in range(config.num_hidden_layers):
|
|
|
+ keys = tuple(
|
|
|
+ [
|
|
|
+ key.replace("layers.0.", f"layers.{i}.")
|
|
|
+ for key in fuse_gate_up_keys
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ final_actions[keys] = partial(fn, split_nums=2)
|
|
|
+ return final_actions
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2Model(Qwen2PretrainedModel):
|
|
|
+ """
|
|
|
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config: Qwen2Config
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, config: Qwen2Config):
|
|
|
+ super().__init__(config)
|
|
|
+ self.padding_idx = config.pad_token_id
|
|
|
+ self.vocab_size = config.vocab_size
|
|
|
+
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.sequence_parallel = config.sequence_parallel
|
|
|
+ self.recompute_granularity = config.recompute_granularity
|
|
|
+ self.no_recompute_layers = (
|
|
|
+ config.no_recompute_layers if config.no_recompute_layers is not None else []
|
|
|
+ )
|
|
|
+
|
|
|
+ # Recompute defaults to False and is controlled by Trainer
|
|
|
+ self.enable_recompute = False
|
|
|
+ if (
|
|
|
+ config.tensor_parallel_degree > 1
|
|
|
+ and config.vocab_size % config.tensor_parallel_degree == 0
|
|
|
+ ):
|
|
|
+ self.embed_tokens = mpu.VocabParallelEmbedding(
|
|
|
+ self.vocab_size,
|
|
|
+ self.hidden_size,
|
|
|
+ weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.embed_tokens = nn.Embedding(
|
|
|
+ self.vocab_size,
|
|
|
+ self.hidden_size,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.layers = nn.LayerList(
|
|
|
+ [
|
|
|
+ Qwen2DecoderLayer(
|
|
|
+ config=config,
|
|
|
+ layerwise_recompute=layer_idx not in self.no_recompute_layers,
|
|
|
+ )
|
|
|
+ for layer_idx in range(config.num_hidden_layers)
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ self.norm = Qwen2RMSNorm(config)
|
|
|
+
|
|
|
+ def get_input_embeddings(self):
|
|
|
+ return self.embed_tokens
|
|
|
+
|
|
|
+ def set_input_embeddings(self, value):
|
|
|
+ self.embed_tokens = value
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _prepare_decoder_attention_mask(
|
|
|
+ attention_mask, input_shape, past_key_values_length, dtype
|
|
|
+ ):
|
|
|
+ if attention_mask is not None:
|
|
|
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
|
+ if len(attention_mask.shape) == 2:
|
|
|
+ expanded_attn_mask = _expand_2d_mask(
|
|
|
+ attention_mask, dtype, tgt_length=input_shape[-1]
|
|
|
+ )
|
|
|
+ # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
|
|
|
+ if input_shape[-1] > 1:
|
|
|
+ combined_attention_mask = _make_causal_mask(
|
|
|
+ input_shape,
|
|
|
+ past_key_values_length=past_key_values_length,
|
|
|
+ )
|
|
|
+ expanded_attn_mask = expanded_attn_mask & combined_attention_mask
|
|
|
+ # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
|
|
|
+ elif len(attention_mask.shape) == 3:
|
|
|
+ expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
|
|
|
+ # if attention_mask is already 4-D, do nothing
|
|
|
+ else:
|
|
|
+ expanded_attn_mask = attention_mask
|
|
|
+ else:
|
|
|
+ expanded_attn_mask = _make_causal_mask(
|
|
|
+ input_shape,
|
|
|
+ past_key_values_length=past_key_values_length,
|
|
|
+ )
|
|
|
+ # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
|
|
|
+ if get_device_type() == "xpu":
|
|
|
+ x = paddle.to_tensor(0.0, dtype="float32")
|
|
|
+ y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
|
|
|
+ expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
|
|
|
+ else:
|
|
|
+ expanded_attn_mask = paddle.where(
|
|
|
+ expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min
|
|
|
+ ).astype(dtype)
|
|
|
+ return expanded_attn_mask
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ input_ids: paddle.Tensor = None,
|
|
|
+ position_ids: Optional[paddle.Tensor] = None,
|
|
|
+ attention_mask: Optional[paddle.Tensor] = None,
|
|
|
+ inputs_embeds: Optional[paddle.Tensor] = None,
|
|
|
+ use_cache: Optional[bool] = None,
|
|
|
+ past_key_values: Optional[List[paddle.Tensor]] = None,
|
|
|
+ output_attentions: Optional[bool] = None,
|
|
|
+ output_hidden_states: Optional[bool] = None,
|
|
|
+ return_dict: Optional[bool] = None,
|
|
|
+ attn_mask_startend_row_indices=None,
|
|
|
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
|
|
|
+
|
|
|
+ output_attentions = (
|
|
|
+ output_attentions
|
|
|
+ if output_attentions is not None
|
|
|
+ else self.config.output_attentions
|
|
|
+ )
|
|
|
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
|
|
|
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
+ return_dict = (
|
|
|
+ return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
+ )
|
|
|
+
|
|
|
+ # retrieve input_ids and inputs_embeds
|
|
|
+ if input_ids is not None and inputs_embeds is not None:
|
|
|
+ raise ValueError(
|
|
|
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
|
|
+ )
|
|
|
+ elif input_ids is not None:
|
|
|
+ batch_size, seq_length = input_ids.shape
|
|
|
+ elif inputs_embeds is not None:
|
|
|
+ batch_size, seq_length, _ = inputs_embeds.shape
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
|
|
+ )
|
|
|
+
|
|
|
+ if past_key_values is None:
|
|
|
+ past_key_values = tuple([None] * len(self.layers))
|
|
|
+        # NOTE: convert to a list so that cache entries can be cleared in time
|
|
|
+ past_key_values = list(past_key_values)
|
|
|
+
|
|
|
+ seq_length_with_past = seq_length
|
|
|
+ cache_length = 0
|
|
|
+ if past_key_values[0] is not None:
|
|
|
+ cache_length = past_key_values[0][0].shape[1]
|
|
|
+ seq_length_with_past += cache_length
|
|
|
+ if inputs_embeds is None:
|
|
|
+ # [bs, seq_len, dim]
|
|
|
+ inputs_embeds = self.embed_tokens(input_ids)
|
|
|
+
|
|
|
+ if self.sequence_parallel:
|
|
|
+ # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
|
|
|
+ bs, seq_len, hidden_size = inputs_embeds.shape
|
|
|
+ inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size])
|
|
|
+ # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
|
|
|
+ inputs_embeds = ScatterOp.apply(inputs_embeds)
|
|
|
+
|
|
|
+ # [bs, seq_len]
|
|
|
+ attention_mask = (
|
|
|
+ paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
|
|
|
+ if attention_mask is None
|
|
|
+ else attention_mask
|
|
|
+ )
|
|
|
+ attention_mask = self._prepare_decoder_attention_mask(
|
|
|
+ attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
|
|
|
+ ) # [bs, 1, seq_len, seq_len]
|
|
|
+ if self.config.use_flash_attention:
|
|
|
+ attention_mask = None if is_casual_mask(attention_mask) else attention_mask
|
|
|
+
|
|
|
+ if position_ids is None:
|
|
|
+ position_ids = paddle.arange(seq_length, dtype="int64").expand(
|
|
|
+ (batch_size, seq_length)
|
|
|
+ )
|
|
|
+
|
|
|
+ hidden_states = inputs_embeds
|
|
|
+
|
|
|
+ # decoder layers
|
|
|
+ all_hidden_states = () if output_hidden_states else None
|
|
|
+ all_self_attns = () if output_attentions else None
|
|
|
+ next_decoder_cache = () if use_cache else None
|
|
|
+
|
|
|
+ for idx, (decoder_layer) in enumerate(self.layers):
|
|
|
+ if output_hidden_states:
|
|
|
+ all_hidden_states += (hidden_states,)
|
|
|
+ past_key_value = (
|
|
|
+ past_key_values[idx] if past_key_values is not None else None
|
|
|
+ )
|
|
|
+
|
|
|
+ has_gradient = not hidden_states.stop_gradient
|
|
|
+ if (
|
|
|
+ self.enable_recompute
|
|
|
+ and idx not in self.no_recompute_layers
|
|
|
+ and has_gradient
|
|
|
+ and self.recompute_granularity == "full"
|
|
|
+ ):
|
|
|
+ layer_outputs = self.recompute_training_full(
|
|
|
+ decoder_layer,
|
|
|
+ hidden_states,
|
|
|
+ position_ids,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ past_key_value,
|
|
|
+ use_cache,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ layer_outputs = decoder_layer(
|
|
|
+ hidden_states,
|
|
|
+ position_ids,
|
|
|
+ attention_mask,
|
|
|
+ output_attentions,
|
|
|
+ past_key_value,
|
|
|
+ use_cache,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ )
|
|
|
+
|
|
|
+            # NOTE: clear the outdated cache entry right after use to save memory
|
|
|
+ past_key_value = past_key_values[idx] = None
|
|
|
+ if type(layer_outputs) is tuple:
|
|
|
+ hidden_states = layer_outputs[0]
|
|
|
+ else:
|
|
|
+ hidden_states = layer_outputs
|
|
|
+
|
|
|
+ if output_attentions:
|
|
|
+ all_self_attns += (layer_outputs[1],)
|
|
|
+
|
|
|
+ if use_cache:
|
|
|
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
|
|
+
|
|
|
+ hidden_states = self.norm(hidden_states)
|
|
|
+
|
|
|
+ # add hidden states from the last decoder layer
|
|
|
+ if output_hidden_states:
|
|
|
+ all_hidden_states += (hidden_states,)
|
|
|
+
|
|
|
+ next_cache = next_decoder_cache if use_cache else None
|
|
|
+
|
|
|
+ if not return_dict:
|
|
|
+ return tuple(
|
|
|
+ v
|
|
|
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
|
|
+ if v is not None
|
|
|
+ )
|
|
|
+ return BaseModelOutputWithPast(
|
|
|
+ last_hidden_state=hidden_states,
|
|
|
+ past_key_values=next_cache,
|
|
|
+ hidden_states=all_hidden_states,
|
|
|
+ attentions=all_self_attns,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2PretrainingCriterion(nn.Layer):
|
|
|
+ """
|
|
|
+    Criterion for Qwen2.
|
|
|
+ It calculates the final loss.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, config: Qwen2Config):
|
|
|
+ super(Qwen2PretrainingCriterion, self).__init__()
|
|
|
+ self.ignore_index = getattr(config, "ignore_index", -100)
|
|
|
+ self.config = config
|
|
|
+ self.enable_parallel_cross_entropy = (
|
|
|
+ config.tensor_parallel_degree > 1 and config.tensor_parallel_output
|
|
|
+ )
|
|
|
+
|
|
|
+ if (
|
|
|
+ self.enable_parallel_cross_entropy
|
|
|
+ ): # and False: # and lm_head is distributed
|
|
|
+ self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
|
|
|
+ else:
|
|
|
+ self.loss_func = paddle.nn.CrossEntropyLoss(
|
|
|
+ reduction="none", ignore_index=self.ignore_index
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, prediction_scores, masked_lm_labels):
|
|
|
+ if self.enable_parallel_cross_entropy:
|
|
|
+ if prediction_scores.shape[-1] == self.config.vocab_size:
|
|
|
+ logging.warning(
|
|
|
+ f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}"
|
|
|
+ )
|
|
|
+ self.loss_func = paddle.nn.CrossEntropyLoss(
|
|
|
+ reduction="none", ignore_index=self.ignore_index
|
|
|
+ )
|
|
|
+
|
|
|
+ with paddle.amp.auto_cast(False):
|
|
|
+ masked_lm_loss = self.loss_func(
|
|
|
+ prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
|
|
|
+ )
|
|
|
+
|
|
|
+ # skip ignore_index which loss == 0
|
|
|
+ # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
|
|
|
+ # loss = paddle.mean(masked_lm_loss)
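+            # Average the loss over tokens whose loss is non-zero (ignore_index tokens
+            # contribute exactly zero), guarding against division by zero when every
+            # token is masked out.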
|
|
|
+ binary_sequence = paddle.where(
|
|
|
+ masked_lm_loss > 0,
|
|
|
+ paddle.ones_like(masked_lm_loss),
|
|
|
+ paddle.zeros_like(masked_lm_loss),
|
|
|
+ )
|
|
|
+ count = paddle.sum(binary_sequence)
|
|
|
+ if count == 0:
|
|
|
+ loss = paddle.sum(masked_lm_loss * binary_sequence)
|
|
|
+ else:
|
|
|
+ loss = paddle.sum(masked_lm_loss * binary_sequence) / count
|
|
|
+
|
|
|
+ return loss
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2LMHead(nn.Layer):
|
|
|
+ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
|
|
|
+ super(Qwen2LMHead, self).__init__()
|
|
|
+ self.config = config
|
|
|
+ if (
|
|
|
+ config.tensor_parallel_degree > 1
|
|
|
+ and config.vocab_size % config.tensor_parallel_degree == 0
|
|
|
+ ):
|
|
|
+ vocab_size = config.vocab_size // config.tensor_parallel_degree
|
|
|
+ else:
|
|
|
+ vocab_size = config.vocab_size
|
|
|
+
|
|
|
+ self.transpose_y = transpose_y
|
|
|
+ if transpose_y:
|
|
|
+ if embedding_weights is not None:
|
|
|
+ self.weight = embedding_weights
|
|
|
+ else:
|
|
|
+ self.weight = self.create_parameter(
|
|
|
+ shape=[vocab_size, config.hidden_size],
|
|
|
+ dtype=paddle.get_default_dtype(),
|
|
|
+ )
|
|
|
+        else:
+            self.weight = self.create_parameter(
+                shape=[config.hidden_size, vocab_size],
+                dtype=paddle.get_default_dtype(),
+            )
|
|
|
+
|
|
|
+ # Must set distributed attr for Tensor Parallel !
|
|
|
+ self.weight.is_distributed = (
|
|
|
+ True if (vocab_size != config.vocab_size) else False
|
|
|
+ )
|
|
|
+ if self.weight.is_distributed:
|
|
|
+ # for tie_word_embeddings
|
|
|
+ self.weight.split_axis = 0 if self.transpose_y else 1
|
|
|
+
|
|
|
+ def forward(self, hidden_states, tensor_parallel_output=None):
|
|
|
+ if self.config.sequence_parallel:
|
|
|
+ hidden_states = GatherOp.apply(hidden_states)
|
|
|
+ seq_length = self.config.seq_length
|
|
|
+ hidden_states = paddle.reshape_(
|
|
|
+ hidden_states, [-1, seq_length, self.config.hidden_size]
|
|
|
+ )
|
|
|
+
|
|
|
+ if tensor_parallel_output is None:
|
|
|
+ tensor_parallel_output = self.config.tensor_parallel_output
|
|
|
+
|
|
|
+ logits = parallel_matmul(
|
|
|
+ hidden_states,
|
|
|
+ self.weight,
|
|
|
+ transpose_y=self.transpose_y,
|
|
|
+ tensor_parallel_output=tensor_parallel_output,
|
|
|
+ )
|
|
|
+ return logits
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2ForCausalLM(Qwen2PretrainedModel):
|
|
|
+ enable_to_static_method = True
|
|
|
+ _tied_weights_keys = ["lm_head.weight"]
|
|
|
+
|
|
|
+ def __init__(self, config: Qwen2Config):
|
|
|
+ super().__init__(config)
|
|
|
+ self.qwen2 = Qwen2Model(config)
|
|
|
+ if config.tie_word_embeddings:
|
|
|
+ self.lm_head = Qwen2LMHead(
|
|
|
+ config,
|
|
|
+ embedding_weights=self.qwen2.embed_tokens.weight,
|
|
|
+ transpose_y=True,
|
|
|
+ )
|
|
|
+ self.tie_weights()
|
|
|
+ else:
|
|
|
+ self.lm_head = Qwen2LMHead(config)
|
|
|
+ self.criterion = Qwen2PretrainingCriterion(config)
|
|
|
+ self.vocab_size = config.vocab_size
|
|
|
+
|
|
|
+ def get_input_embeddings(self):
|
|
|
+ return self.qwen2.embed_tokens
|
|
|
+
|
|
|
+ def set_input_embeddings(self, value):
|
|
|
+ self.qwen2.embed_tokens = value
|
|
|
+
|
|
|
+ def get_output_embeddings(self):
|
|
|
+ return self.lm_head
|
|
|
+
|
|
|
+ def set_output_embeddings(self, new_embeddings):
|
|
|
+ self.lm_head = new_embeddings
|
|
|
+
|
|
|
+ def set_decoder(self, decoder):
|
|
|
+ self.qwen2 = decoder
|
|
|
+
|
|
|
+ def get_decoder(self):
|
|
|
+ return self.qwen2
|
|
|
+
|
|
|
+ def prepare_inputs_for_generation(
|
|
|
+ self,
|
|
|
+ input_ids,
|
|
|
+ use_cache=False,
|
|
|
+ past_key_values=None,
|
|
|
+ attention_mask=None,
|
|
|
+ inputs_embeds=None,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
+ batch_size, seq_length = input_ids.shape
|
|
|
+ position_ids = kwargs.get(
|
|
|
+ "position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))
|
|
|
+ )
|
|
|
+ if past_key_values:
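+            # With a populated cache only the newest token has to be fed forward;
+            # earlier positions are already encoded in past_key_values.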
|
|
|
+ input_ids = input_ids[:, -1].unsqueeze(axis=-1)
|
|
|
+ position_ids = position_ids[:, -1].unsqueeze(-1)
|
|
|
+
|
|
|
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
|
+ if inputs_embeds is not None and past_key_values is None:
|
|
|
+ model_inputs = {"inputs_embeds": inputs_embeds}
|
|
|
+ else:
|
|
|
+ model_inputs = {"input_ids": input_ids}
|
|
|
+
|
|
|
+ model_inputs.update(
|
|
|
+ {
|
|
|
+ "position_ids": position_ids,
|
|
|
+ "past_key_values": past_key_values,
|
|
|
+ "use_cache": use_cache,
|
|
|
+ "attention_mask": attention_mask,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return model_inputs
|
|
|
+
|
|
|
+ def _get_model_inputs_spec(self, dtype: str):
|
|
|
+ return {
|
|
|
+ "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
|
|
|
+ "attention_mask": paddle.static.InputSpec(
|
|
|
+ shape=[None, None], dtype="int64"
|
|
|
+ ),
|
|
|
+ "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
|
|
|
+ }
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def update_model_kwargs_for_generation(
|
|
|
+ outputs, model_kwargs, is_encoder_decoder=False
|
|
|
+ ):
|
|
|
+ # update cache
|
|
|
+ if (
|
|
|
+ isinstance(outputs, tuple)
|
|
|
+ and len(outputs) > 1
|
|
|
+ and not isinstance(outputs[1], paddle.Tensor)
|
|
|
+ ):
|
|
|
+ model_kwargs["past_key_values"] = outputs[1]
|
|
|
+
|
|
|
+ if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
|
|
|
+ model_kwargs["past_key_values"] = outputs.past_key_values
|
|
|
+
|
|
|
+ # update position_ids
|
|
|
+ if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
|
|
|
+ position_ids = model_kwargs["position_ids"]
|
|
|
+ model_kwargs["position_ids"] = paddle.concat(
|
|
|
+ [position_ids, position_ids[..., -1:] + 1], axis=-1
|
|
|
+ )
|
|
|
+
|
|
|
+ if not is_encoder_decoder and "attention_mask" in model_kwargs:
|
|
|
+ # TODO: support attention mask for other models
|
|
|
+ attention_mask = model_kwargs["attention_mask"]
|
|
|
+ if len(attention_mask.shape) == 2:
|
|
|
+ model_kwargs["attention_mask"] = paddle.concat(
|
|
|
+ [
|
|
|
+ attention_mask,
|
|
|
+ paddle.ones(
|
|
|
+ [attention_mask.shape[0], 1], dtype=attention_mask.dtype
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ axis=-1,
|
|
|
+ )
|
|
|
+ elif len(attention_mask.shape) == 4:
|
|
|
+ model_kwargs["attention_mask"] = paddle.concat(
|
|
|
+ [
|
|
|
+ attention_mask,
|
|
|
+ paddle.ones(
|
|
|
+ [*attention_mask.shape[:3], 1], dtype=attention_mask.dtype
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ axis=-1,
|
|
|
+ )[:, :, -1:, :]
|
|
|
+
|
|
|
+ return model_kwargs
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ input_ids: paddle.Tensor = None,
|
|
|
+ position_ids: Optional[paddle.Tensor] = None,
|
|
|
+ attention_mask: Optional[paddle.Tensor] = None,
|
|
|
+ inputs_embeds: Optional[paddle.Tensor] = None,
|
|
|
+ labels: Optional[paddle.Tensor] = None,
|
|
|
+ use_cache: Optional[bool] = None,
|
|
|
+ past_key_values: Optional[List[paddle.Tensor]] = None,
|
|
|
+ output_attentions: Optional[bool] = None,
|
|
|
+ output_hidden_states: Optional[bool] = None,
|
|
|
+ return_dict: Optional[bool] = None,
|
|
|
+ attn_mask_startend_row_indices=None,
|
|
|
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
|
+ r"""
|
|
|
+ Args:
|
|
|
+ labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
|
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
|
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+
|
|
|
+ Example:
|
|
|
+
|
|
|
+ ```python
|
|
|
+        >>> from paddlenlp.transformers import AutoTokenizer, Qwen2ForCausalLM
|
|
|
+
|
|
|
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
|
|
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
|
|
+
|
|
|
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
+        >>> inputs = tokenizer(prompt, return_tensors="pd")
|
|
|
+
|
|
|
+ >>> # Generate
|
|
|
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
|
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
|
+ ```"""
|
|
|
+
|
|
|
+ output_attentions = (
|
|
|
+ output_attentions
|
|
|
+ if output_attentions is not None
|
|
|
+ else self.config.output_attentions
|
|
|
+ )
|
|
|
+ output_hidden_states = (
|
|
|
+ output_hidden_states
|
|
|
+ if output_hidden_states is not None
|
|
|
+ else self.config.output_hidden_states
|
|
|
+ )
|
|
|
+ return_dict = (
|
|
|
+ return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
+ )
|
|
|
+
|
|
|
+ if attn_mask_startend_row_indices is not None and attention_mask is not None:
|
|
|
+ logging.warning(
|
|
|
+ "You have provided both attn_mask_startend_row_indices and attention_mask. "
|
|
|
+ "The attn_mask_startend_row_indices will be used."
|
|
|
+ )
|
|
|
+ attention_mask = None
|
|
|
+
|
|
|
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
+ outputs = self.qwen2(
|
|
|
+ input_ids=input_ids,
|
|
|
+ position_ids=position_ids,
|
|
|
+ attention_mask=attention_mask,
|
|
|
+ inputs_embeds=inputs_embeds,
|
|
|
+ use_cache=use_cache,
|
|
|
+ past_key_values=past_key_values,
|
|
|
+ output_attentions=output_attentions,
|
|
|
+ output_hidden_states=output_hidden_states,
|
|
|
+ return_dict=return_dict,
|
|
|
+ attn_mask_startend_row_indices=attn_mask_startend_row_indices,
|
|
|
+ )
|
|
|
+
|
|
|
+ hidden_states = outputs[0]
|
|
|
+
|
|
|
+        # if labels is None, we need the full output instead of tensor_parallel_output;
+        # tensor_parallel_output goes together with ParallelCrossEntropy
|
|
|
+ tensor_parallel_output = (
|
|
|
+ self.config.tensor_parallel_output
|
|
|
+ and self.config.tensor_parallel_degree > 1
|
|
|
+ )
|
|
|
+
|
|
|
+ logits = self.lm_head(
|
|
|
+ hidden_states, tensor_parallel_output=tensor_parallel_output
|
|
|
+ )
|
|
|
+ loss = None
|
|
|
+
|
|
|
+ if not return_dict:
|
|
|
+ output = (logits,) + outputs[1:]
|
|
|
+ return (loss,) + output if loss is not None else output
|
|
|
+
|
|
|
+ return CausalLMOutputWithPast(
|
|
|
+ loss=loss,
|
|
|
+ logits=logits,
|
|
|
+ past_key_values=outputs.past_key_values,
|
|
|
+ hidden_states=outputs.hidden_states,
|
|
|
+ attentions=outputs.attentions,
|
|
|
+ )
|