- # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Paddle Ernie model"""
- import contextlib
- import functools
- from functools import partial
- from typing import Optional, Tuple
- import numpy as np
- import paddle
- import paddle.distributed as dist
- import paddle.nn.functional as F
- from paddle import incubate, nn, tensor
- from paddle.autograd import PyLayer
- from paddle.distributed import fleet
- from paddle.distributed.fleet.layers.mpu import mp_ops
- from paddle.distributed.fleet.layers.mpu.mp_layers import (
- ColumnParallelLinear,
- RowParallelLinear,
- VocabParallelEmbedding,
- )
- from paddle.distributed.fleet.meta_parallel import (
- ParallelCrossEntropy,
- get_rng_state_tracker,
- )
- from paddle.distributed.fleet.utils import recompute
- from ......utils import logging
- from ....common.vlm.transformers import PretrainedModel
- from ....common.vlm.transformers.model_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- )
- from ._config import PaddleOCRVLConfig
- from ._distributed import (
- AllGatherVarlenOp,
- ColumnParallelLinear,
- ColumnSequenceParallelLinear,
- GatherOp,
- RowParallelLinear,
- RowSequenceParallelLinear,
- RRColumnSequenceParallelLinear,
- RRRowSequenceParallelLinear,
- mark_as_sequence_parallel_parameter,
- parallel_matmul,
- sequence_parallel_sparse_mask_labels,
- )
- from ._fusion_ops import (
- Linear,
- fused_rms_norm_ext,
- fused_swiglu,
- fusion_flash_attention,
- )
- from ._sequence_parallel_utils import ScatterOp
- def calc_lm_head_logits(
- config, hidden_states, weight, bias, tensor_parallel_output=None, training=True
- ):
- """
- Calculate language model head logits with support for various parallelization strategies.
- This is the core function that computes the final output logits for a language model,
- handling sequence parallelism and tensor parallelism configurations.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- hidden_states (Tensor): Hidden states from the transformer layers
- weight (Tensor): Weight matrix for the language model head
- bias (Tensor): Bias vector for the language model head
- tensor_parallel_output (bool, optional): Override for tensor parallel output behavior.
- If None, uses config.tensor_parallel_output.
- Defaults to None.
- training (bool, optional): Whether in training mode. Defaults to True.
- Returns:
- Tensor: The computed logits for language modeling.
- """
- if config.sequence_parallel:
- if config.use_sparse_head_and_loss_fn:
- pass # Nothing needs to be done.
- else:
- hidden_states = GatherOp.apply(hidden_states)
- max_sequence_length = config.max_sequence_length
- hidden_states = hidden_states.reshape(
- [-1, max_sequence_length, hidden_states.shape[-1]]
- )
- if tensor_parallel_output is None:
- tensor_parallel_output = config.tensor_parallel_output
- logits = parallel_matmul(
- hidden_states,
- weight,
- bias=bias,
- transpose_y=config.tie_word_embeddings,
- tensor_parallel_degree=config.tensor_parallel_degree,
- tensor_parallel_output=tensor_parallel_output,
- fuse_linear=config.fuse_linear,
- training=training,
- )
- return logits
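- # Illustrative note (a sketch, not part of the model): when tensor_parallel_degree == 1 and
- # sequence_parallel is off, parallel_matmul above reduces to an ordinary projection, roughly
- #
- #     logits = paddle.matmul(hidden_states, weight, transpose_y=config.tie_word_embeddings)
- #     if bias is not None:
- #         logits = logits + bias
- #
- # i.e. hidden states of shape [batch, seq, hidden] are mapped to logits of shape [batch, seq, vocab].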
- def subbatch(f, arg_idx, axis, bs, out_idx, use_recompute=False, same_arg_idx={}):
- """
- Converts a function into one that is applied over subbatches of an input dimension.
- This is useful for processing large tensors in smaller chunks to reduce memory usage.
- Args:
- f (Callable): Original function to be converted to subbatch processing.
- arg_idx ([int]): Indices of the inputs to be subbatched.
- axis ([int]): Indices of the dimensions to be subbatched for each input.
- bs (int): Subbatch size (number of elements to process at once).
- out_idx (int): Index of the output dimension that needs stacking.
- use_recompute (bool, optional): Whether to use recomputation for memory savings. Defaults to False.
- same_arg_idx (dict, optional): Mapping of argument indices that share the same tensor.
- e.g. {1: 0} means args[1] == args[0], avoiding duplicate slicing.
- Returns:
- Callable: Converted function that processes inputs in subbatches.
- """
- @functools.wraps(f)
- def wrapper(*args, **kwargs):
- assert len(arg_idx) == len(
- axis
- ), "Number of batching args and number of batching dims should match."
- inps = [args[i] for i in arg_idx]
- axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
- assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
- inp_axis = {inp: d for inp, d in zip(inps, axis)}
- axis_width = axis_width[0]
- if axis_width < bs:
- return f(*args, **kwargs)
- outs = []
- for slice_at in np.arange(0, axis_width, bs):
- _args = []
- for i, inp in enumerate(args):
- if i in same_arg_idx:
- assert (
- i > same_arg_idx[i]
- ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
- _args.append(_args[same_arg_idx[i]])
- elif i in arg_idx:
- inp = inp.slice(
- [inp_axis[inp]],
- [slice_at],
- [min(inp.shape[inp_axis[inp]], slice_at + bs)],
- )
- _args.append(inp)
- else:
- _args.append(inp)
- if use_recompute:
- out = paddle.distributed.fleet.utils.recompute(f, *_args, **kwargs)
- else:
- out = f(*_args, **kwargs)
- outs.append(out)
- return paddle.concat(outs, out_idx)
- return wrapper
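- # Usage sketch for `subbatch` (the function `score` and the sizes below are hypothetical):
- #
- #     def score(a, b):                      # a: [N, D], b: [N, D]
- #         return (a * b).sum(-1, keepdim=True)
- #
- #     chunked_score = subbatch(score, arg_idx=[0, 1], axis=[0, 0], bs=1024, out_idx=0)
- #     out = chunked_score(a, b)             # same result as score(a, b), computed 1024 rows at a time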
- def _rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return paddle.concat((-x2, x1), axis=-1)
- def _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- mrope_section = mrope_section * 2
- cos = paddle.concat(
- [m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1
- ).unsqueeze(unsqueeze_dim)
- sin = paddle.concat(
- [m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1
- ).unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (_rotate_half(q) * sin)
- k_embed = (k * cos) + (_rotate_half(k) * sin)
- return q_embed, k_embed
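- # Note on the two helpers above: they implement the standard rotary identity
- #     q_embed = q * cos + rotate_half(q) * sin   (and likewise for k),
- # while the multimodal ("mrope") variant first re-assembles cos/sin along the last axis
- # according to `mrope_section`, so that the temporal, height and width position ids
- # (the leading axis of size 3) each rotate their own slice of the head dimension.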
- class FusedDropoutImpl(nn.Layer):
- """
- Fused dropout implementation with residual connection support.
- This layer combines dropout and residual addition in a single operation for better performance,
- particularly on GPU devices. The dropout is conditionally applied based on the probability.
- Args:
- prob (float): Dropout probability (between 0 and 1)
- mode (str): Dropout mode, either 'upscale_in_train' or 'downscale_in_infer'
- Attributes:
- prob (float): Stores the dropout probability
- mode (str): Stores the dropout mode
- dropout (nn.Dropout): The actual dropout layer instance
- """
- def __init__(self, prob, mode):
- """
- Initialize the fused dropout layer.
- Args:
- prob (float): Dropout probability (0 means no dropout)
- mode (str): Dropout mode ('upscale_in_train' or 'downscale_in_infer')
- """
- super().__init__()
- self.prob = prob
- self.mode = mode
- self.dropout = nn.Dropout(p=prob, mode=mode)
- def forward(self, x, y):
- """
- Forward pass of the fused dropout layer.
- Args:
- x (Tensor): Input tensor to potentially apply dropout on
- y (Tensor): Residual tensor to add to the (possibly dropped out) x
- Returns:
- Tensor: Result of x (with optional dropout) + y
- """
- if self.prob > 0:
- x = self.dropout(x)
- output = x + y
- return output
- class RMSNorm(nn.Layer):
- """
- Root Mean Square Layer Normalization (RMSNorm) implementation.
- RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
- omitting the mean-centering operation. This provides computational efficiency while maintaining
- good performance.
- """
- def __init__(self, config):
- """
- Initialize RMSNorm layer.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- """
- super().__init__()
- self.hidden_size = config.hidden_size
- self.weight = paddle.create_parameter(
- shape=[self.hidden_size],
- dtype=paddle.get_default_dtype(),
- default_initializer=nn.initializer.Constant(1.0),
- )
- self.variance_epsilon = config.rms_norm_eps
- self.config = config
- if config.sequence_parallel:
- mark_as_sequence_parallel_parameter(self.weight)
- def forward(self, hidden_states):
- """
- Apply RMS normalization to input hidden states.
- Args:
- hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
- Returns:
- Tensor: Normalized output tensor of same shape as input
- Note:
- - Uses fused kernel if config.fuse_rms_norm is True for better performance
- - Otherwise computes RMSNorm manually:
- 1. Compute variance of features
- 2. Apply reciprocal square root normalization
- 3. Scale by learned weight parameter
- - Maintains original dtype for numerical stability during computation
- """
- if self.config.fuse_rms_norm:
- return fused_rms_norm_ext(
- hidden_states, self.weight, self.variance_epsilon
- )[0].astype(self.weight.dtype)
- with paddle.amp.auto_cast(False):
- variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
- hidden_states = (
- paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
- )
- return hidden_states.astype(self.weight.dtype) * self.weight
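- # The unfused branch above computes, per token vector x (shorthand, not code in this module):
- #     rms = sqrt(mean(x ** 2) + eps)
- #     y   = (x / rms) * weight
- # with the statistics taken in float32 and the result cast back to the weight dtype,
- # matching the fused kernel up to numerical details.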
- class LayerNorm(nn.LayerNorm):
- """
- Layer Normalization (LayerNorm) implementation with optional optimizations.
- This extends PaddlePaddle's built-in LayerNorm with:
- 1. Sequence parallelism support
- 2. Fast fused kernel implementation option
- 3. Configurable epsilon value
- """
- def __init__(self, config):
- """
- Initialize LayerNorm with configuration.
- Args:
- config (PaddleOCRVLConfig): Model configuration contains normalization parameters and flags.
- """
- super().__init__(config.hidden_size, epsilon=config.rms_norm_eps)
- self.config = config
- if config.sequence_parallel:
- mark_as_sequence_parallel_parameter(self.weight)
- mark_as_sequence_parallel_parameter(self.bias)
- class KeyeRotaryEmbedding(nn.Layer):
- def __init__(self, config: PaddleOCRVLConfig, device=None):
- super().__init__()
- self.rope_kwargs = {}
- if config is None:
- raise NotImplementedError
- else:
- # BC: "rope_type" was originally "type"
- if config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get(
- "rope_type", config.rope_scaling.get("type")
- )
- else:
- self.rope_type = "default"
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get(
- "rope_type", config.rope_scaling.get("type")
- )
- else:
- self.rope_type = "default"
- self.config = config
- if self.rope_type == "default":
- dim = config.head_dim
- inv_freq = 1.0 / (
- config.rope_theta
- ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)
- )
- self.attention_scaling = 1.0
- else:
- raise ValueError(f"Unsupported rope type: {self.rope_type}")
- self.register_buffer("inv_freq", inv_freq, persistable=False)
- self.original_inv_freq = self.inv_freq
- @paddle.no_grad()
- def forward(self, x, position_ids):
- # Core RoPE block. In contrast to other models, Keye has different position ids for the grids
- # So we expand the inv_freq to shape (3, ...)
- inv_freq_expanded = (
- self.inv_freq[None, None, :, None]
- .cast("float32")
- .expand((3, position_ids.shape[1], -1, 1))
- )
- position_ids_expanded = position_ids[:, :, None, :].cast(
- "float32"
- ) # shape (3, bs, 1, positions)
- with paddle.amp.auto_cast(enable=False):
- freqs = (
- inv_freq_expanded.cast("float32")
- @ position_ids_expanded.cast("float32")
- ).transpose((0, 1, 3, 2))
- emb = paddle.concat((freqs, freqs), axis=-1)
- cos = emb.cos()
- sin = emb.sin()
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
- return cos.astype(x.dtype), sin.astype(x.dtype)
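- # Shape note for the forward above: position_ids arrives as (3, bs, positions), one row of
- # ids per temporal/height/width axis, so cos/sin leave this layer as (3, bs, positions, head_dim)
- # and are recombined per `mrope_section` inside _apply_multimodal_rotary_pos_emb.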
- class Ernie4_5MLP(nn.Layer):
- """
- Ernie4_5MLP - Gated Multi-Layer Perceptron module used in the Ernie model.
- """
- def __init__(self, config, layer_idx=0):
- """
- Initialize the MLP module with configuration options.
- Args:
- config (PaddleOCRVLConfig): Model configurations.
- layer_idx (int): Index of current layer (default: 0)
- """
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- if config.tensor_parallel_degree > 1:
- ColumnLN = (
- ColumnSequenceParallelLinear
- if config.sequence_parallel
- else ColumnParallelLinear
- )
- RowLN = (
- RowSequenceParallelLinear
- if config.sequence_parallel
- else RowParallelLinear
- )
- column_ln_configs = {}
- if (
- config.recompute
- and config.sequence_parallel
- and config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False)
- ):
- ColumnLN = RRColumnSequenceParallelLinear
- column_ln_configs = {"use_rr": True}
- self.up_gate_proj = ColumnLN(
- self.hidden_size,
- self.intermediate_size * 2,
- gather_output=False,
- has_bias=config.use_bias,
- fuse_matmul_bias=config.fuse_linear,
- **column_ln_configs,
- )
- else:
- LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
- self.up_gate_proj = LinearFN(
- self.hidden_size, self.intermediate_size * 2, bias_attr=config.use_bias
- )
- if config.tensor_parallel_degree > 1:
- row_ln_configs = {}
- if (
- config.recompute
- and config.sequence_parallel
- and config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False)
- ):
- RowLN = RRRowSequenceParallelLinear
- row_ln_configs = {"use_rr": True}
- self.down_proj = RowLN(
- self.intermediate_size,
- self.hidden_size,
- input_is_parallel=True,
- has_bias=config.use_bias,
- fuse_matmul_bias=config.fuse_linear,
- **row_ln_configs,
- )
- else:
- LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
- self.down_proj = LinearFN(
- self.intermediate_size, self.hidden_size, bias_attr=config.use_bias
- )
- self.fuse_swiglu = config.fuse_swiglu
- if self.fuse_swiglu:
- assert fused_swiglu is not None, "fused_swiglu operator is not found."
- def forward(self, x):
- """
- Forward pass through the MLP module.
- Args:
- x (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
- Returns:
- Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
- Note:
- Implements SwiGLU activation: swish(Wx) * (Vx) where W and V are
- the first and second halves of up_gate_proj output respectively.
- """
- if self.fuse_swiglu:
- x = self.up_gate_proj(x)
- x = fused_swiglu(x)
- else:
- gate, x = self.up_gate_proj(x).chunk(2, axis=-1)
- x = F.silu(gate) * x
- return self.down_proj(x)
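- # The unfused path above is a standard SwiGLU feed-forward; as a minimal sketch:
- #     gate, up = up_gate_proj(x).chunk(2, axis=-1)
- #     y = down_proj(F.silu(gate) * up)
- # The fused path computes the same thing through a single fused_swiglu kernel.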
- class Ernie4_5Attention(nn.Layer):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config, layer_idx=0):
- """Initialize the attention layer.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- layer_idx (int, optional): Index in transformer stack. Defaults to 0.
- """
- super().__init__()
- self.layer_idx = layer_idx
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.num_key_value_heads = config.num_key_value_heads
- if getattr(config, "head_dim", None) is None:
- self.head_dim = self.hidden_size // self.num_heads
- else:
- self.head_dim = config.head_dim
- self.is_gqa = (
- config.num_key_value_heads is not None
- and config.num_key_value_heads != self.num_heads
- )
- self.rope_scaling = config.rope_scaling
- self.freq_allocation = config.get("freq_allocation", 0)
- if config.tensor_parallel_degree > 1:
- assert (
- self.num_heads % config.tensor_parallel_degree == 0
- ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
- self.num_heads = self.num_heads // config.tensor_parallel_degree
- if self.is_gqa:
- assert (
- self.num_key_value_heads % config.tensor_parallel_degree == 0
- ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
- self.num_key_value_heads = (
- self.num_key_value_heads // config.tensor_parallel_degree
- )
- if self.is_gqa:
- logging.info(
- f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}"
- )
- assert (
- self.num_heads % self.num_key_value_heads == 0
- ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
- if getattr(config, "head_dim", None) is None:
- kv_hidden_size = (
- self.hidden_size // self.num_heads * self.num_key_value_heads
- )
- else:
- kv_hidden_size = self.head_dim * config.num_key_value_heads
- q_hidden_size = self.head_dim * config.num_attention_heads
- else:
- q_hidden_size = kv_hidden_size = self.head_dim * config.num_attention_heads
- if config.tensor_parallel_degree > 1:
- column_ln_configs = {}
- ColumnLN = (
- ColumnSequenceParallelLinear
- if config.sequence_parallel
- else ColumnParallelLinear
- )
- RowLN = (
- RowSequenceParallelLinear
- if config.sequence_parallel
- else RowParallelLinear
- )
- if (
- config.recompute
- and config.sequence_parallel
- and config.skip_recompute_ops[layer_idx].get(
- "attention_column_ln", False
- )
- ):
- ColumnLN = RRColumnSequenceParallelLinear
- column_ln_configs = {"use_rr": True}
- if getattr(config, "head_dim", None) is None:
- qkv_hidden_size = (
- self.hidden_size * 3
- if not self.is_gqa
- else self.hidden_size + kv_hidden_size * 2
- )
- else:
- qkv_hidden_size = q_hidden_size + kv_hidden_size * 2
- self.qkv_proj = ColumnLN(
- self.hidden_size,
- qkv_hidden_size,
- has_bias=config.use_bias,
- gather_output=False,
- fuse_matmul_bias=config.fuse_linear,
- **column_ln_configs,
- )
- else:
- LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
- if getattr(config, "head_dim", None) is None:
- qkv_hidden_size = (
- self.hidden_size * 3
- if not self.is_gqa
- else self.hidden_size + kv_hidden_size * 2
- )
- else:
- qkv_hidden_size = q_hidden_size + kv_hidden_size * 2
- self.qkv_proj = LinearFN(
- self.hidden_size,
- qkv_hidden_size,
- bias_attr=config.use_bias,
- )
- if config.tensor_parallel_degree > 1:
- row_ln_configs = {}
- if (
- config.recompute
- and config.sequence_parallel
- and config.skip_recompute_ops[layer_idx].get("attention_row_ln", False)
- ):
- RowLN = RRRowSequenceParallelLinear
- row_ln_configs = {"use_rr": True}
- self.o_proj = RowLN(
- (
- self.hidden_size
- if getattr(config, "head_dim", None) is None
- else q_hidden_size
- ),
- self.hidden_size,
- has_bias=config.use_bias,
- input_is_parallel=True,
- fuse_matmul_bias=config.fuse_linear,
- **row_ln_configs,
- )
- else:
- LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
- self.o_proj = LinearFN(
- (
- self.hidden_size
- if getattr(config, "head_dim", None) is None
- else q_hidden_size
- ),
- self.hidden_size,
- bias_attr=config.use_bias,
- )
- self.config = config
- self._rr_flash_attn = None
- if config.recompute and config.skip_recompute_ops[layer_idx].get(
- "flash_attn", False
- ):
- # TODO
- raise NotImplementedError
- self.set_attn_func()
- def set_attn_func(self):
- """Configure attention function based on settings.
- Selects between flash/core attention.
- """
- config = self.config
- if config.use_flash_attention:
- self.attn_func = self._flash_attention_wrapper
- else:
- self.attn_func = self.core_attn
- if config.cachekv_quant:
- # TODO: Support `cachekv_quant`
- raise NotImplementedError
- def forward(
- self,
- hidden_states,
- position_embeddings,
- past_key_value: Optional[Tuple[paddle.Tensor]] = None,
- attention_mask: Optional[paddle.Tensor] = None,
- attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
- position_ids: Optional[Tuple[paddle.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- token_type_ids: Optional[Tuple[paddle.Tensor]] = None, # MLLM
- ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
- """Compute attention outputs.
- Args:
- hidden_states (paddle.Tensor): Input tensor [bsz, seq_len, hidden_size]
- position_embeddings (paddle.Tensor): Position embeddings
- past_key_value (Optional[Tuple[paddle.Tensor, paddle.Tensor]]): Cached key/value states
- attention_mask (Optional[paddle.Tensor]): Attention mask tensor
- attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length attention indices
- position_ids (Optional[paddle.Tensor]): Position indices for RoPE
- output_attentions (bool): Return attention weights if True
- use_cache (bool): Cache key/value states if True
- Returns:
- Tuple containing:
- - attention_output: [bsz, seq_len, hidden_size]
- - attention_weights: Optional attention probabilities
- - updated_key_value_cache: Optional updated cache
- """
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, :-1]
- if self.config.sequence_parallel:
- if token_type_ids is not None:
- token_type_ids = token_type_ids.reshape([-1])
- token_type_ids = ScatterOp.apply(token_type_ids)
- token_type_ids.stop_gradient = True
- max_sequence_length = self.config.max_sequence_length
- bsz = (
- hidden_states.shape[0]
- * self.config.tensor_parallel_degree
- // max_sequence_length
- )
- q_len = max_sequence_length
- else:
- bsz, q_len, _ = hidden_states.shape
- query_states = key_states = value_states = mix_layer = None
- mix_layer = self.qkv_proj(hidden_states)
- if self.is_gqa:
- query_states, key_states, value_states = paddle.split(
- mix_layer.reshape([bsz, q_len, -1, self.head_dim]),
- [self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
- axis=2,
- )
- mix_layer = None
- else:
- mix_layer = mix_layer.reshape(
- [bsz, q_len, self.num_heads, 3 * self.head_dim]
- )
- if mix_layer is not None:
- has_gradient = not mix_layer.stop_gradient
- else:
- has_gradient = not (
- query_states.stop_gradient
- and key_states.stop_gradient
- and value_states.stop_gradient
- )
- if (
- self.config.recompute
- and self.config.recompute_granularity == "core_attn"
- and has_gradient
- ):
- assert past_key_value is None, "do not use kv cache in recompute"
- assert not use_cache
- attn_output, attn_weights, past_key_value = recompute(
- self.rope_attn,
- mix_layer,
- query_states,
- key_states,
- value_states,
- position_embeddings,
- attention_mask,
- position_ids,
- output_attentions,
- past_key_value,
- use_cache,
- attn_mask_start_row_indices,
- use_reentrant=self.config.recompute_use_reentrant,
- )
- else:
- attn_output, attn_weights, past_key_value = self.rope_attn(
- mix_layer=mix_layer,
- query_states=query_states,
- key_states=key_states,
- value_states=value_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- position_ids=position_ids,
- output_attentions=output_attentions,
- past_key_value=past_key_value,
- use_cache=use_cache,
- attn_mask_start_row_indices=attn_mask_start_row_indices,
- )
- if self.config.sequence_parallel:
- attn_output = attn_output.reshape([-1, attn_output.shape[-1]])
- attn_output = self.o_proj(attn_output)
- if not output_attentions:
- attn_weights = None
- return attn_output, attn_weights, past_key_value
- def _flash_attention_wrapper(
- self,
- q,
- k,
- v,
- attention_mask=None,
- attn_mask_start_row_indices=None,
- seq_length=None,
- ):
- """Optimized flash attention implementation.
- Args:
- q (paddle.Tensor): Query tensor
- k (paddle.Tensor): Key tensor
- v (paddle.Tensor): Value tensor
- attention_mask (Optional[paddle.Tensor]): Attention mask
- attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length indices
- seq_length (Optional[int]): Sequence length
- Returns:
- paddle.Tensor: Attention output tensor
- """
- return fusion_flash_attention(
- q,
- k,
- v,
- self.training,
- self.config.attention_probs_dropout_prob,
- self.config.use_sparse_flash_attn,
- attention_mask,
- attn_mask_start_row_indices,
- seq_length,
- self.config.use_var_len_flash_attn,
- self._rr_flash_attn if self.training else None,
- )
- def core_attn(
- self,
- q,
- k,
- v,
- attention_mask=None,
- attn_mask_start_row_indices=None,
- seq_length=None,
- ):
- """Standard self-attention implementation.
- Args:
- q (paddle.Tensor): Query tensor
- k (paddle.Tensor): Key tensor
- v (paddle.Tensor): Value tensor
- attention_mask (Optional[paddle.Tensor]): Attention mask
- attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length indices
- seq_length (Optional[int]): Sequence length
- Returns:
- Tuple[paddle.Tensor, paddle.Tensor]: Attention output and weights
- """
- perm = [
- 0,
- 2,
- 1,
- 3,
- ] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3]
- origin_dtype = q.dtype
- q = tensor.transpose(x=q, perm=perm)
- k = tensor.transpose(x=k, perm=perm)
- v = tensor.transpose(x=v, perm=perm)
- replicate = self.config.num_attention_heads // self.config.num_key_value_heads
- k = paddle.repeat_interleave(k, replicate, axis=1)
- v = paddle.repeat_interleave(v, replicate, axis=1)
- scale_qk_coeff = self.config.scale_qk_coeff * self.head_dim**0.5
- product = paddle.matmul(x=q.scale(1.0 / scale_qk_coeff), y=k, transpose_y=True)
- product = product.cast(paddle.float32)
- if self.config.scale_qk_coeff != 1.0:
- product = product.scale(self.config.scale_qk_coeff)
- if attention_mask is not None:
- attention_mask = attention_mask.cast(paddle.float32)
- if self.config.fuse_softmax_mask:
- weights = incubate.softmax_mask_fuse(product, attention_mask)
- else:
- product = product + attention_mask
- weights = F.softmax(product)
- else:
- weights = incubate.softmax_mask_fuse_upper_triangle(product)
- weights = weights.cast(origin_dtype)
- if self.config.attention_probs_dropout_prob:
- with get_rng_state_tracker().rng_state("local_seed"):
- weights = F.dropout(
- weights,
- self.config.attention_probs_dropout_prob,
- training=self.training,
- mode="upscale_in_train",
- )
- out = paddle.matmul(weights, v)
- # combine heads
- out = tensor.transpose(out, perm=[0, 2, 1, 3])
- # If sequence_parallel is true, out shape is [s, b, h] after reshape
- # else out shape is [b, s, h]
- out = tensor.reshape(x=out, shape=[0, 0, -1])
- return out, weights
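- # core_attn above is plain scaled dot-product attention,
- #     softmax(Q @ K^T / sqrt(head_dim) + mask) @ V,
- # with GQA handled by repeat_interleave so every query head gets a key/value head, and with
- # scale_qk_coeff folded into the query beforehand and multiplied back into the product
- # afterwards (net scaling 1 / sqrt(head_dim)).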
- def rope_attn(
- self,
- mix_layer,
- query_states,
- key_states,
- value_states,
- position_embeddings,
- attention_mask,
- position_ids,
- output_attentions=False,
- past_key_value=None,
- use_cache=False,
- attn_mask_start_row_indices=None,
- ):
- if mix_layer is not None:
- query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1)
- query_states_dtype = query_states.dtype
- kv_seq_len = position_ids.max() + 1
- offset = 0
- if past_key_value is not None:
- # LLM
- offset = past_key_value[0].shape[-3]
- kv_seq_len += offset
- query_states = query_states.astype(query_states_dtype)
- key_states = key_states.astype(query_states_dtype)
- if position_ids.dim() == 3 and position_ids.shape[0] > 1:
- position_ids = position_ids[0:1]
- cos, sin = position_embeddings
- query_states, key_states = _apply_multimodal_rotary_pos_emb(
- query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], 2
- )
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = paddle.concat([past_key_value[0], key_states], axis=1)
- value_states = paddle.concat([past_key_value[1], value_states], axis=1)
- # NOTE(for generation): use list instead of tuple to store the cache
- # tensors, so that we can clear the cache tensors for memory efficiency.
- past_key_value = [key_states, value_states] if use_cache else None
- seq_length = query_states.shape[1]
- attn_output, attn_weights = self.attn_func(
- query_states,
- key_states,
- value_states,
- attention_mask,
- attn_mask_start_row_indices,
- seq_length,
- )
- return attn_output, attn_weights, past_key_value
- class FusedHeadParallelCrossEntropy(PyLayer):
- """Fused parallel cross-entropy loss computation for large sequence lengths.
- Combines head projection and loss computation with optimized memory usage for long sequences,
- supporting tensor parallel training.
- """
- @staticmethod
- def forward(
- ctx,
- hidden_states,
- weight,
- bias,
- labels,
- tensor_parallel_degree,
- mp_group=None,
- ignore_index=-100,
- seq_chunk_size=8192,
- transpose_y=False,
- fuse_linear=False,
- training=True,
- ):
- """Forward pass for parallel cross-entropy computation.
- Args:
- ctx: Context object for saving tensors between forward/backward
- hidden_states (paddle.Tensor): Input tensor of shape [batch_size*seq_len, hidden_size]
- weight (paddle.Tensor): Weight matrix for projection
- bias (Optional[paddle.Tensor]): Optional bias vector
- labels (paddle.Tensor): Target labels tensor of shape [batch_size*seq_len]
- tensor_parallel_degree (int): Degree of tensor parallelism
- mp_group (Optional[dist.Group]): Model parallel group. Defaults to None (auto-detect)
- ignore_index (int): Index to ignore in loss computation. Defaults to -100
- seq_chunk_size (int): Chunk size for processing long sequences. Defaults to 8192
- transpose_y (bool): Whether to transpose weight matrix. Defaults to False
- fuse_linear (bool): Whether to use fused linear ops. Defaults to False
- training (bool): Whether in training mode. Defaults to True
- Returns:
- Tuple[paddle.Tensor, paddle.Tensor]:
- - loss: Computed loss tensor
- - gathered_labels: Concatenated labels from all parallel groups
- """
- ctx.tensor_parallel_degree = tensor_parallel_degree
- ctx.ignore_index = ignore_index
- ctx.seq_chunk_size = seq_chunk_size
- ctx.transpose_y = transpose_y
- ctx.fuse_linear = fuse_linear
- ctx.training = training
- ctx.hidden_states_shape = hidden_states.shape
- ctx.mp_group = (
- fleet.get_hybrid_communicate_group().get_model_parallel_group()
- if mp_group is None
- else mp_group
- )
- ctx.rank = ctx.mp_group.rank
- ctx.world_size = ctx.mp_group.nranks
- loss_all = []
- labels_all = []
- with paddle.no_grad():
- labels = labels.reshape_([-1])
- hidden_states = hidden_states.reshape_([-1, hidden_states.shape[-1]])
- num_tokens_per_rank = []
- dist.stream.all_gather(
- num_tokens_per_rank,
- paddle.to_tensor(hidden_states.shape[0], dtype=paddle.int32),
- group=ctx.mp_group,
- )
- ctx.num_tokens_per_rank = num_tokens_per_rank
- for idx in range(ctx.world_size):
- if idx == ctx.rank:
- hidden_states_recv = hidden_states
- labels_recv = labels
- else:
- hidden_states_recv = paddle.empty(
- [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]],
- dtype=hidden_states.dtype,
- )
- labels_recv = paddle.empty(
- [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype
- )
- dist.stream.broadcast(
- hidden_states_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
- )
- dist.stream.broadcast(
- labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
- )
- seq_len = hidden_states_recv.shape[0]
- num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size
- loss_chunk = []
- for chunk_idx in range(num_chunk):
- start = chunk_idx * ctx.seq_chunk_size
- end = min(start + ctx.seq_chunk_size, seq_len)
- hidden_states_chunk = hidden_states_recv._slice(start, end)
- labels_chunk = labels_recv._slice(start, end)
- logits = parallel_matmul(
- hidden_states_chunk,
- weight,
- bias=bias,
- transpose_y=ctx.transpose_y,
- tensor_parallel_degree=ctx.tensor_parallel_degree,
- tensor_parallel_output=True,
- fuse_linear=ctx.fuse_linear,
- training=ctx.training,
- )
- with paddle.amp.auto_cast(False):
- loss = mp_ops._c_softmax_with_cross_entropy(
- logits.cast("float32"),
- labels_chunk.unsqueeze(-1),
- group=ctx.mp_group,
- ignore_index=ctx.ignore_index,
- )
- loss_chunk.append(loss)
- loss_all.append(paddle.concat(loss_chunk, axis=0))
- labels_all.append(labels_recv)
- ctx.loss_concat_sections = [loss.shape[0] for loss in loss_all]
- loss_all = paddle.concat(loss_all, axis=0)
- labels_all = paddle.concat(labels_all, axis=0)
- tensor_inputs = [hidden_states, weight, bias, labels]
- ctx.save_for_backward(*tensor_inputs)
- return loss_all, labels_all
- @staticmethod
- def backward(ctx, loss_all_grad, labels_all_grad):
- """Backward pass for parallel cross-entropy computation.
- Args:
- ctx: Context object with saved tensors from forward
- loss_all_grad (paddle.Tensor): Gradient of loss
- labels_all_grad (paddle.Tensor): Gradient of labels (unused)
- Returns:
- Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[paddle.Tensor], None]:
- - hidden_states_grad: Gradient for input hidden states
- - weight_grad: Gradient for weight matrix (None if not trainable)
- - bias_grad: Gradient for bias vector (None if not trainable or not provided)
- - None: Placeholder for labels gradient
- """
- hidden_states, weight, bias, labels = ctx.saved_tensor()
- loss_all_grad_list = paddle.split(
- loss_all_grad, ctx.loss_concat_sections, axis=0
- )
- def detach_variable(inp):
- if inp is None:
- return None
- x = inp.detach()
- x.stop_gradient = inp.stop_gradient
- return x
- if weight.stop_gradient is False:
- weight_main_grad = paddle.zeros(weight.shape, dtype=paddle.float32)
- else:
- weight_main_grad = None
- if bias is not None and bias.stop_gradient is False:
- bias_main_grad = paddle.zeros(bias.shape, dtype=paddle.float32)
- else:
- bias_main_grad = None
- hidden_states = detach_variable(hidden_states)
- weight = detach_variable(weight)
- bias = detach_variable(bias)
- labels = detach_variable(labels)
- with paddle.base.dygraph.guard():
- tracer = paddle.base.framework._dygraph_tracer()
- tracer._has_grad = True
- for idx in range(ctx.world_size):
- if idx == ctx.rank:
- hidden_states_recv = hidden_states
- labels_recv = labels
- else:
- hidden_states_recv = paddle.empty(
- [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]],
- dtype=hidden_states.dtype,
- )
- labels_recv = paddle.empty(
- [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype
- )
- dist.stream.broadcast(
- hidden_states_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
- )
- dist.stream.broadcast(
- labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
- )
- hidden_states_recv.stop_gradient = False
- seq_len = hidden_states_recv.shape[0]
- num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size
- for chunk_idx in range(num_chunk):
- start = chunk_idx * ctx.seq_chunk_size
- end = min(start + ctx.seq_chunk_size, seq_len)
- hidden_states_chunk = hidden_states_recv.slice(
- axes=[0], starts=[start], ends=[end]
- )
- labels_chunk = labels_recv._slice(start, end)
- loss_grad_chunk = loss_all_grad_list[idx]._slice(start, end)
- logits = parallel_matmul(
- hidden_states_chunk,
- weight,
- bias=bias,
- transpose_y=ctx.transpose_y,
- tensor_parallel_degree=ctx.tensor_parallel_degree,
- tensor_parallel_output=True,
- fuse_linear=ctx.fuse_linear,
- training=ctx.training,
- )
- with paddle.amp.auto_cast(False):
- loss_chunk = mp_ops._c_softmax_with_cross_entropy(
- logits.cast("float32"),
- labels_chunk.unsqueeze(-1),
- group=ctx.mp_group,
- ignore_index=ctx.ignore_index,
- )
- with paddle.amp.auto_cast(enable=False):
- paddle.autograd.backward(loss_chunk, loss_grad_chunk)
- if weight_main_grad is not None:
- weight_main_grad.add_(weight.grad.cast(paddle.float32))
- weight.clear_gradient(True)
- if bias_main_grad is not None:
- bias_main_grad.add_(bias.grad.cast(paddle.float32))
- bias.clear_gradient(True)
- if idx == ctx.rank:
- hidden_states_grad = hidden_states_recv.grad
- hidden_states_grad = hidden_states_grad.reshape(
- ctx.hidden_states_shape
- )
- if weight_main_grad is not None:
- weight_main_grad = weight_main_grad.astype(weight.dtype)
- if bias_main_grad is not None:
- bias_main_grad = bias_main_grad.astype(bias.dtype)
- return (
- hidden_states_grad,
- weight_main_grad,
- bias_main_grad,
- None,
- )
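- # In short: in the forward pass each rank broadcasts its tokens to the whole model-parallel
- # group, logits are materialized only seq_chunk_size tokens at a time, and the loss uses the
- # vocab-parallel softmax cross-entropy kernel; the backward pass replays the same chunking and
- # accumulates weight/bias gradients in float32, so the full logits tensor never exists at once.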
- class ErniePretrainingCriterion(paddle.nn.Layer):
- """Criterion for ERNIE pretraining task."""
- def __init__(self, config, return_tuple=True):
- """Initialize the pretraining criterion.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
- """
- super(ErniePretrainingCriterion, self).__init__()
- self.ignored_index = getattr(config, "ignored_index", -100)
- self.config = config
- self.return_tuple = return_tuple
- self.enable_parallel_cross_entropy = (
- config.tensor_parallel_degree > 1 and config.tensor_parallel_output
- )
- if (
- self.enable_parallel_cross_entropy
- ): # and False: # and lm_head is distributed
- logging.info("using parallel cross entroy, take care")
- self.loss_func = ParallelCrossEntropy()
- else:
- self.loss_func = paddle.nn.CrossEntropyLoss(
- reduction="none",
- )
- self.token_balance_loss = config.token_balance_loss
- def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):
- """Compute the pretraining loss.
- Args:
- prediction_scores (Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]):
- Either:
- - Direct logits tensor [batch_size, seq_len, vocab_size]
- - Tuple of (hidden_states, weight, bias) for sparse head computation
- masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
- loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens. Defaults to None.
- Returns:
- Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
- - If return_tuple=False: Single loss tensor
- - If return_tuple=True: Tuple of (normalized_loss, sum_loss)
- """
- if self.config.use_sparse_head_and_loss_fn:
- hidden_states, outlinear_weight, outlinear_bias = prediction_scores[:3]
- if self.config.sequence_parallel:
- masked_lm_labels, sparse_label_idx = (
- sequence_parallel_sparse_mask_labels(
- masked_lm_labels, self.ignored_index
- )
- )
- sparse_label_idx = sparse_label_idx.reshape([-1, 1])
- hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0)
- hidden_states = AllGatherVarlenOp.apply(hidden_states)
- else:
- masked_lm_labels = masked_lm_labels.flatten()
- sparse_label_idx = paddle.nonzero(
- masked_lm_labels != self.ignored_index
- ).flatten()
- masked_lm_labels = paddle.take_along_axis(
- masked_lm_labels, sparse_label_idx, axis=0
- )
- hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
- hidden_states = paddle.take_along_axis(
- hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0
- )
- # `loss_mask` must be reset to None and re-calculated in ErniePretrainingCriterion
- # when use_sparse_head_and_loss_fn is enabled.
- loss_mask = None
- if self.config.use_recompute_loss_fn:
- offload_kwargs = {}
- if self.config.get("offload_lm_head", False):
- offload_kwargs["offload_indices"] = [1]
- res = recompute(
- self.forward_impl_with_calc_logits,
- masked_lm_labels,
- loss_mask,
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- **offload_kwargs,
- )
- else:
- logits = calc_lm_head_logits(
- self.config,
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- training=self.training,
- )
- res = self.forward_impl(logits, masked_lm_labels, loss_mask)
- elif self.config.use_recompute_loss_fn:
- if self.config.use_fused_head_and_loss_fn:
- res = self.forward_impl_with_fused_head_loss_fn(
- masked_lm_labels, loss_mask, *prediction_scores
- )
- else:
- assert isinstance(prediction_scores, tuple) and len(
- prediction_scores
- ) in [3, 4], prediction_scores
- res = recompute(
- self.forward_impl_with_calc_logits,
- masked_lm_labels,
- loss_mask,
- *prediction_scores,
- )
- else:
- res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
- return res
- def forward_impl_with_fused_head_loss_fn(
- self,
- masked_lm_labels,
- loss_mask,
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- ):
- """Compute loss with fused head and parallel cross-entropy.
- Args:
- masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
- loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens
- hidden_states (paddle.Tensor): Hidden states from transformer [batch_size, seq_len, hidden_size]
- outlinear_weight (paddle.Tensor): Weight matrix for output projection
- outlinear_bias (Optional[paddle.Tensor]): Optional bias for output projection
- Returns:
- Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
- Same return format as forward()
- """
- assert (
- self.config.tensor_parallel_degree > 0
- ), "use_fused_head_and_loss_fn require tensor_parallel_degree > 0"
- masked_lm_loss, masked_lm_labels_all = FusedHeadParallelCrossEntropy.apply(
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- masked_lm_labels,
- self.config.tensor_parallel_degree,
- ignore_index=self.ignored_index,
- seq_chunk_size=self.config.get("loss_subbatch_seqlen", 32768),
- transpose_y=self.config.tie_word_embeddings,
- fuse_linear=self.config.fuse_linear,
- training=self.training,
- )
- if loss_mask is None:
- loss_mask = masked_lm_labels_all != self.ignored_index
- if (~loss_mask).all(): # empty span
- logging.warning(
- f"encounter empty span when calculate loss, ignored_index={self.ignored_index}"
- )
- loss = paddle.mean(masked_lm_loss) * 0.0
- loss_sum = masked_lm_loss.sum().detach()
- else:
- loss_mask = loss_mask.reshape([-1]).cast(paddle.float32)
- # Align per token; accumulate in full precision.
- masked_lm_loss = paddle.sum(
- masked_lm_loss.cast(paddle.float32).reshape([-1]) * loss_mask
- )
- loss = masked_lm_loss / loss_mask.sum()
- if self.token_balance_loss:
- _loss = masked_lm_loss / self.config.token_balance_seqlen
- loss = _loss - _loss.detach() + loss.detach()  # for alignment
- loss_sum = masked_lm_loss.sum().detach()
- if not self.return_tuple: # only used in pp
- if self.training:
- return loss
- return loss_sum
- return loss, loss_sum
- def forward_impl_with_calc_logits(
- self,
- masked_lm_labels,
- loss_mask,
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- ):
- """Compute logits then calculate loss.
- Args:
- Same as forward_impl_with_fused_head_loss_fn()
- Returns:
- Same return format as forward()
- """
- logits = calc_lm_head_logits(
- self.config,
- hidden_states,
- outlinear_weight,
- outlinear_bias,
- training=self.training,
- )
- return self.forward_impl(logits, masked_lm_labels, loss_mask)
- def loss_impl(self, prediction_scores, masked_lm_labels):
- """Core loss computation without reduction.
- Args:
- prediction_scores (paddle.Tensor): Logits tensor [batch_size, seq_len, vocab_size]
- masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
- Returns:
- paddle.Tensor: Unreduced loss tensor
- """
- prediction_scores = prediction_scores.cast("float32")
- masked_lm_loss = self.loss_func(
- prediction_scores, masked_lm_labels.unsqueeze(-1)
- )
- return masked_lm_loss
- def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
- """Standard loss computation with reduction and masking.
- Args:
- prediction_scores (paddle.Tensor): Logits tensor [batch_size, seq_len, vocab_size]
- masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
- loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens
- Returns:
- Same return format as forward()
- """
- if self.enable_parallel_cross_entropy:
- assert prediction_scores.shape[-1] != self.config.vocab_size, (
- f"enable_parallel_cross_entropy, the vocab_size should be splited:"
- f" {prediction_scores.shape[-1]}, {self.config.vocab_size}"
- )
- with paddle.amp.auto_cast(False):
- prediction_scores_dims = len(prediction_scores.shape)
- if prediction_scores_dims == 2 and prediction_scores.shape[
- 0
- ] > self.config.get("loss_subbatch_seqlen", 32768):
- sb_loss_func = subbatch(
- self.loss_impl,
- [0, 1],
- [0, 0],
- self.config.get("loss_subbatch_seqlen", 32768),
- 0,
- )
- masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
- elif prediction_scores_dims == 3 and prediction_scores.shape[
- 1
- ] > self.config.get("loss_subbatch_seqlen", 32768):
- sb_loss_func = subbatch(
- self.loss_impl,
- [0, 1],
- [1, 1],
- self.config.get("loss_subbatch_seqlen", 32768),
- 1,
- )
- masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
- else:
- masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
- if loss_mask is None:
- loss_mask = masked_lm_labels != self.ignored_index
- lossmask = masked_lm_labels != self.ignored_index
- if (~lossmask).all(): # empty span
- logging.warning(
- f"encounter empty span when calculate loss, ignored_index={self.ignored_index}"
- )
- loss = paddle.mean(masked_lm_loss) * 0.0
- loss_sum = masked_lm_loss.sum().detach()
- else:
- loss_mask = loss_mask.reshape([-1]).cast(paddle.float32)
- # Align per token; accumulate in full precision.
- masked_lm_loss = paddle.sum(
- masked_lm_loss.cast(paddle.float32).reshape([-1]) * loss_mask
- )
- loss = masked_lm_loss / loss_mask.sum()
- if self.token_balance_loss:
- _loss = masked_lm_loss / self.config.token_balance_seqlen
- loss = _loss - _loss.detach() + loss.detach()  # for alignment
- loss_sum = masked_lm_loss.sum().detach()
- if not self.return_tuple: # only used in pp
- if self.training:
- return loss
- return loss_sum
- return loss, loss_sum
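- # Reduction sketch for forward_impl: with loss_mask defaulting to (labels != ignored_index),
- #     loss     = sum(per_token_loss * mask) / sum(mask)
- #     loss_sum = sum(per_token_loss * mask)   (detached)
- # and when token_balance_loss is set, the gradient comes from
- # sum(per_token_loss * mask) / token_balance_seqlen while the reported value stays `loss`.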
- class Ernie4_5LMHead(nn.Layer):
- """Language model head for ERNIE with support for tensor parallelism."""
- def __init__(self, config):
- """Initialize the language model head.
- Args:
- config (PaddleOCRVLConfig): Model configuration containing:
- - vocab_size: Size of vocabulary
- - hidden_size: Dimension of hidden states
- - tensor_parallel_degree: Degree of tensor parallelism
- - tie_word_embeddings: Whether to tie input/output embeddings
- - weight_share_add_bias: Whether to add bias when weight sharing
- - use_bias: Whether to use bias term
- - use_recompute_loss_fn: Whether to defer logits computation to loss function
- - use_sparse_head_and_loss_fn: Whether to use sparse head computation
- """
- super(Ernie4_5LMHead, self).__init__()
- self.config = config
- if config.tensor_parallel_degree > 1:
- vocab_size = config.vocab_size // config.tensor_parallel_degree
- else:
- vocab_size = config.vocab_size
- self.weight = self.create_parameter(
- shape=(
- [vocab_size, config.hidden_size]
- if config.tie_word_embeddings
- else [config.hidden_size, vocab_size]
- ),
- dtype=paddle.get_default_dtype(),
- )
- logging.info(
- f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}"
- )
- if config.weight_share_add_bias and config.use_bias:
- self.bias = self.create_parameter(
- shape=[vocab_size],
- dtype=paddle.get_default_dtype(),
- attr=paddle.ParamAttr(
- initializer=paddle.nn.initializer.constant.Constant(0.0)
- ),
- )
- else:
- self.bias = None
- # Must set distributed attr for Tensor Parallel !
- self.weight.is_distributed = (
- True if (vocab_size != config.vocab_size) else False
- )
- if config.weight_share_add_bias and config.use_bias:
- self.bias.is_distributed = (
- True if (vocab_size != config.vocab_size) else False
- )
- if self.weight.is_distributed:
- self.weight.split_axis = 1
- if (
- config.weight_share_add_bias
- and config.use_bias
- and self.bias.is_distributed
- ):
- self.bias.split_axis = 0
- if self.config.use_recompute_loss_fn:
- logging.info(
- "Using recompute_loss_fn, the calculation of logits will be moved into "
- "loss_fn for memory optimization"
- )
- def forward(self, hidden_states, tensor_parallel_output=None):
- """Project hidden states to vocabulary logits.
- Args:
- hidden_states (paddle.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
- tensor_parallel_output (Optional[bool]): Whether to output parallel results. Defaults to None.
- Returns:
- Union[
- Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor]]:
- # When use_recompute_loss_fn or use_sparse_head_and_loss_fn
- - hidden_states: Original input
- - weight: Projection weights
- - bias: Optional bias term
- Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor], bool]: # With tensor_parallel_output
- Same as above plus tensor_parallel_output flag
- paddle.Tensor: # Normal case
- Logits tensor of shape [batch_size, seq_len, vocab_size]
- ]
- """
- # will enter this branch when:
- # 1. use_recompute_loss_fn or use_sparse_head_and_loss_fn
- # 2. dpo training
- if self.config.use_recompute_loss_fn or self.config.use_sparse_head_and_loss_fn:
- return (
- hidden_states,
- self.weight,
- self.bias,
- self.config.tie_word_embeddings,
- )
- return calc_lm_head_logits(
- self.config,
- hidden_states,
- self.weight,
- self.bias,
- tensor_parallel_output,
- training=self.training,
- )
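- # Weight-layout note: with tie_word_embeddings the head keeps the embedding layout
- # [vocab, hidden] and calc_lm_head_logits transposes it on the fly (transpose_y=True);
- # otherwise the weight is stored directly as [hidden, vocab]. Under tensor parallelism each
- # rank holds vocab_size // tensor_parallel_degree of the vocabulary, which is what the
- # is_distributed / split_axis bookkeeping above records.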
- class Ernie4_5DecoderLayer(nn.Layer):
- """A single transformer decoder layer in ERNIE model.
- Contains self-attention and feed-forward components,
- support, residual connections, and layer normalization.
- """
- def __init__(self, config, layer_idx):
- """Initialize the decoder layer.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- layer_idx (int): Index of this layer in the transformer stack
- """
- super().__init__()
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
- self.config = config
- self.self_attn = Ernie4_5Attention(config, layer_idx)
- self.mlp = Ernie4_5MLP(config)
- Norm = RMSNorm if config.use_rmsnorm else LayerNorm
- self.input_layernorm = Norm(config)
- self.post_attention_layernorm = Norm(config)
- self.residual_add1 = FusedDropoutImpl(
- config.hidden_dropout_prob, mode="upscale_in_train"
- )
- self.residual_add2 = FusedDropoutImpl(
- config.hidden_dropout_prob, mode="upscale_in_train"
- )
- if config.sequence_parallel:
- mark_as_sequence_parallel_parameter(self.post_attention_layernorm.weight)
- if not hasattr(config, "disable_ffn_model_parallel"):
- mark_as_sequence_parallel_parameter(self.input_layernorm.weight)
- if config.use_bias:
- mark_as_sequence_parallel_parameter(self.self_attn.o_proj.bias)
- mark_as_sequence_parallel_parameter(self.mlp.down_proj.bias)
- if not config.use_rmsnorm and config.use_bias:
- mark_as_sequence_parallel_parameter(self.post_attention_layernorm.bias)
- mark_as_sequence_parallel_parameter(self.input_layernorm.bias)
- def forward(
- self,
- hidden_states: paddle.Tensor,
- position_embeddings: paddle.Tensor,
- attention_mask: Optional[paddle.Tensor] = None,
- attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
- position_ids: Optional[paddle.Tensor] = None,
- token_type_ids: Optional[paddle.Tensor] = None,
- output_attentions: Optional[bool] = False,
- past_key_value: Optional[Tuple[paddle.Tensor]] = None,
- use_cache: Optional[bool] = False,
- ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
- """Forward pass through the decoder layer.
- Args:
- hidden_states (paddle.Tensor): Input tensor [batch_size, seq_len, hidden_size]
- position_embeddings (paddle.Tensor): Position embeddings
- attention_mask (Optional[paddle.Tensor]): Attention mask tensor
- attn_mask_start_row_indices (Optional[paddle.Tensor]): Indices for variable length attention
- position_ids (Optional[paddle.Tensor]): Position indices for rotary embeddings
- token_type_ids (Optional[paddle.Tensor]): Token type IDs forwarded to the attention layer
- output_attentions (Optional[bool]): Whether to return attention weights
- past_key_value (Optional[Tuple[paddle.Tensor]]): Cached key/value states
- use_cache (Optional[bool]): Whether to cache key/value states
- Returns:
- Union: Various output combinations depending on arguments:
- - Base case: Hidden states tensor
- - With attention: Tuple of (hidden_states, attention_weights)
- - With cache: Tuple of (hidden_states, cached_key_value)
- """
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- # Self Attention
- has_gradient = not hidden_states.stop_gradient
- if (
- self.config.recompute
- and self.config.recompute_granularity == "full_attn"
- and has_gradient
- ):
- hidden_states, self_attn_weights, present_key_value = recompute(
- self.self_attn,
- hidden_states,
- position_embeddings,
- past_key_value,
- attention_mask,
- attn_mask_start_row_indices,
- position_ids,
- output_attentions,
- use_cache,
- use_reentrant=self.config.recompute_use_reentrant,
- )
- else:
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- attn_mask_start_row_indices=attn_mask_start_row_indices,
- position_ids=position_ids,
- output_attentions=output_attentions,
- use_cache=use_cache,
- token_type_ids=token_type_ids,
- )
- with self.model_parallel_dropout():
- hidden_states = self.residual_add1(hidden_states, residual)
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- with self.model_parallel_dropout():
- hidden_states = self.residual_add2(hidden_states, residual)
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
- # unwrap single-element tuple (required for pipeline parallel)
- if type(outputs) is tuple and len(outputs) == 1:
- outputs = outputs[0]
- return outputs
- def model_parallel_dropout(self):
- """Get context manager for model-parallel dropout with proper seed control.
- Returns:
- Context manager for dropout operation
- """
- if (
- self.config.tensor_parallel_degree > 1
- and self.config.hidden_dropout_prob > 0.0
- ):
- current_seed = (
- "local_seed" if self.config.sequence_parallel else "global_seed"
- )
- return get_rng_state_tracker().rng_state(current_seed)
- return contextlib.nullcontext()
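- # --- Hedged sketch (added): the pre-norm residual pattern used by this layer ---
- # Each sub-block runs "norm -> sublayer -> dropout -> residual add": first self-attention,
- # then the MLP. Simplified, ignoring caching, recompute and tensor parallelism; the
- # callables below are placeholders, not the real module attributes.
- def _pre_norm_block_sketch(x, input_norm, attn, post_norm, mlp, dropout):
-     h = x + dropout(attn(input_norm(x)))    # self-attention sub-block
-     return h + dropout(mlp(post_norm(h)))   # feed-forward sub-block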
- class Ernie4_5PretrainedModel(PretrainedModel):
- """Base class for ERNIE pretrained models."""
- config_class = PaddleOCRVLConfig
- base_model_prefix = "ernie"
- @classmethod
- def _get_tensor_parallel_mappings(cls, config, is_split=True):
- """Generate tensor parallel mappings for model conversion.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- is_split (bool): Whether to generate split mappings (True)
- or merge mappings (False). Defaults to True.
- Returns:
- Dict[str, Callable[[Any], Any]]: Dictionary mapping parameter names
- to their corresponding split/merge functions for tensor parallelism.
- """
- from ..conversion_utils import split_or_merge_func
- fn = split_or_merge_func(
- is_split=is_split,
- tensor_parallel_degree=config.tensor_parallel_degree,
- tensor_parallel_rank=config.tensor_parallel_rank,
- num_attention_heads=config.num_attention_heads,
- )
- def gqa_qkv_split_func(
- weight,
- tensor_parallel_degree,
- tensor_parallel_rank,
- num_attention_heads,
- num_key_value_heads,
- head_dim,
- is_quant=False,
- is_split=True,
- ):
- if is_quant:
- weight = weight.T
- def get_shape(tensor):
- return (
- tensor.get_shape() if hasattr(tensor, "get_shape") else tensor.shape
- )
- def slice_tensor(tensor, start, end):
- shape = get_shape(tensor)
- if len(shape) == 1:
- return tensor[start:end]
- else:
- return tensor[..., start:end]
- q_end = num_attention_heads * head_dim
- k_end = q_end + num_key_value_heads * head_dim
- v_end = k_end + num_key_value_heads * head_dim
- q = slice_tensor(weight, 0, q_end)
- k = slice_tensor(weight, q_end, k_end)
- v = slice_tensor(weight, k_end, v_end)
- def split_tensor(tensor, degree):
- shape = get_shape(tensor)
- size = shape[-1]
- block_size = size // degree
- if hasattr(tensor, "get_shape"):
- return [
- slice_tensor(tensor, i * block_size, (i + 1) * block_size)
- for i in range(degree)
- ]
- else:
- return np.split(tensor, degree, axis=-1)
- q_list = split_tensor(q, tensor_parallel_degree)
- k_list = split_tensor(k, tensor_parallel_degree)
- v_list = split_tensor(v, tensor_parallel_degree)
- if tensor_parallel_rank is None:
- out = [
- np.concatenate([q_i, k_i, v_i], axis=-1)
- for q_i, k_i, v_i in zip(q_list, k_list, v_list)
- ]
- else:
- out = np.concatenate(
- [
- q_list[tensor_parallel_rank],
- k_list[tensor_parallel_rank],
- v_list[tensor_parallel_rank],
- ],
- axis=-1,
- )
- if is_quant:
- out = out.T
- return out
- def gqa_qkv_merge_func(
- weight_list,
- num_attention_heads,
- num_key_value_heads,
- head_dim,
- is_quant=False,
- is_split=False,
- ):
- tensor_parallel_degree = len(weight_list)
- num_attention_heads = num_attention_heads // tensor_parallel_degree
- num_key_value_heads = num_key_value_heads // tensor_parallel_degree
- is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
- def get_shape(tensor):
- return (
- tensor.get_shape() if hasattr(tensor, "get_shape") else tensor.shape
- )
- def slice_tensor(tensor, start, end):
- if len(get_shape(tensor)) == 1:
- return tensor[start:end]
- else:
- return tensor[..., start:end]
- q_list, k_list, v_list = [], [], []
- for weight in weight_list:
- if is_quant:
- weight = weight.T
- q_end = num_attention_heads * head_dim
- k_end = q_end + num_key_value_heads * head_dim
- v_end = k_end + num_key_value_heads * head_dim
- q = slice_tensor(weight, 0, q_end)
- k = slice_tensor(weight, q_end, k_end)
- v = slice_tensor(weight, k_end, v_end)
- q_list.append(q)
- k_list.append(k)
- v_list.append(v)
- merged = q_list + k_list + v_list
- if is_paddle_tensor:
- tensor = paddle.concat(merged, axis=-1)
- if tensor.place.is_gpu_place():
- tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
- else:
- tensor = np.concatenate(merged, axis=-1)
- if is_quant:
- tensor = tensor.T
- return tensor
- if (
- config.num_key_value_heads is not None
- and config.num_key_value_heads != config.num_attention_heads
- ):
- if is_split:
- qkv_fn = partial(
- gqa_qkv_split_func,
- tensor_parallel_degree=config.tensor_parallel_degree,
- tensor_parallel_rank=config.tensor_parallel_rank,
- num_attention_heads=config.num_attention_heads,
- num_key_value_heads=config.num_key_value_heads,
- head_dim=(
- config.hidden_size // config.num_attention_heads
- if config.head_dim is None
- else config.head_dim
- ),
- is_quant=False,
- is_split=True,
- )
- else:
- qkv_fn = partial(
- gqa_qkv_merge_func,
- num_attention_heads=config.num_attention_heads,
- num_key_value_heads=config.num_key_value_heads,
- head_dim=(
- config.hidden_size // config.num_attention_heads
- if config.head_dim is None
- else config.head_dim
- ),
- is_quant=False,
- is_split=False,
- )
- else:
- qkv_fn = partial(fn, is_column=True)
- def get_tensor_parallel_split_mappings(num_hidden_layers):
- final_actions = {}
- base_actions = {
- # Column Linear
- "layers.0.self_attn.qkv_proj.weight": qkv_fn,
- "layers.0.mlp.up_gate_proj.weight": partial(
- fn, is_column=True, is_naive_2fuse=True
- ),
- "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
- # Row Linear
- "embed_tokens.weight": partial(fn, is_column=False),
- "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
- "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
- }
- if config.use_bias:
- base_actions.update(
- {
- # Column Linear
- "layers.0.self_attn.qkv_proj.bias": qkv_fn,
- "layers.0.mlp.up_gate_proj.bias": partial(
- fn, is_column=True, is_naive_2fuse=True
- ),
- "layers.0.mlp.down_proj.bias": lambda x: x[
- :
- ], # convert PySafeSlice to ndarray.
- "lm_head.bias": partial(fn, is_column=True),
- }
- )
- for key, action in base_actions.items():
- if "layers.0." in key:
- for i in range(num_hidden_layers):
- final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
- else:
- final_actions[key] = action
- return final_actions
- mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
- return mappings
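- # --- Hedged illustration (added; mirrors the nested gqa_qkv_split_func above) ---
- # A fused QKV weight is laid out as [.., Q | K | V] along its last axis. Under tensor
- # parallelism each rank keeps a contiguous slice of Q, K and V, re-fused in the same
- # order. Uses the module-level `np` import; toy shapes below are for illustration only.
- def _toy_gqa_qkv_split(weight, tp_degree, tp_rank, n_heads, n_kv_heads, head_dim):
-     q_end = n_heads * head_dim
-     k_end = q_end + n_kv_heads * head_dim
-     q, k, v = weight[..., :q_end], weight[..., q_end:k_end], weight[..., k_end:]
-     q_i = np.split(q, tp_degree, axis=-1)[tp_rank]
-     k_i = np.split(k, tp_degree, axis=-1)[tp_rank]
-     v_i = np.split(v, tp_degree, axis=-1)[tp_rank]
-     return np.concatenate([q_i, k_i, v_i], axis=-1)
- # Example: 8 query heads, 2 KV heads, head_dim 4, tp_degree 2 -> fused width
- # (8 + 2 + 2) * 4 = 48 columns; each rank keeps (4 + 1 + 1) * 4 = 24 columns.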
- class Ernie4_5Model(Ernie4_5PretrainedModel):
- """The core ERNIE transformer model"""
- def __init__(self, config: PaddleOCRVLConfig):
- """Initialize the ERNIE model architecture.
- Args:
- config (PaddleOCRVLConfig): Model configuration.
- """
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.hidden_size = config.hidden_size
- self.config = config
- if config.tensor_parallel_degree > 1:
- self.embed_tokens = VocabParallelEmbedding(
- self.vocab_size,
- self.hidden_size,
- )
- else:
- self.embed_tokens = nn.Embedding(
- self.vocab_size,
- self.hidden_size,
- )
- self.layers = nn.LayerList(
- [Ernie4_5DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
- )
- Norm = RMSNorm if config.use_rmsnorm else LayerNorm
- self.norm = Norm(config)
- self.rotary_emb = KeyeRotaryEmbedding(config=config)
- self.gradient_checkpointing = False
- def get_input_embeddings(self):
- """Get the input embedding layer.
- Returns:
- nn.Embedding: The embedding layer for input tokens
- """
- return self.embed_tokens
- def set_input_embeddings(self, value):
- """Set new input embeddings.
- Args:
- value (nn.Embedding): New embedding layer to use
- """
- self.embed_tokens = value
- @paddle.jit.not_to_static
- def recompute_training(
- self,
- layer_module,
- hidden_states,
- position_embeddings,
- attention_mask,
- attn_mask_start_row_indices,
- position_ids,
- token_type_ids,
- output_attentions,
- past_key_value,
- use_cache,
- ):
- """Perform gradient checkpointing for memory-efficient training.
- Args:
- layer_module (nn.Layer): Transformer layer to recompute
- hidden_states (paddle.Tensor): Input hidden states
- position_embeddings (paddle.Tensor): Position embeddings
- attention_mask (paddle.Tensor): Attention mask
- attn_mask_start_row_indices (paddle.Tensor): Variable length indices
- position_ids (paddle.Tensor): Position indices
- token_type_ids (paddle.Tensor): Token type IDs forwarded to the layer
- output_attentions (bool): Whether to output attention weights
- past_key_value (Optional[Tuple[paddle.Tensor]]): Cached key/value states
- use_cache (bool): Whether to cache key/value states
- Returns:
- paddle.Tensor: Output hidden states after recomputation
- """
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs, output_gate_logits=False)
- return custom_forward
- hidden_states = recompute(
- create_custom_forward(layer_module),
- hidden_states,
- position_embeddings,
- attention_mask,
- attn_mask_start_row_indices,
- position_ids,
- token_type_ids,
- output_attentions,
- past_key_value,
- use_cache,
- )
- return hidden_states
- def forward(
- self,
- input_ids=None,
- position_ids=None,
- token_type_ids=None,
- attention_mask=None,
- attn_mask_start_row_indices=None,
- inputs_embeds=None,
- use_cache=None,
- past_key_values=None,
- output_attentions=False,
- output_hidden_states=None,
- return_dict=False,
- ):
- """Forward pass through the ERNIE model.
- Args:
- input_ids (Optional[paddle.Tensor]): Input token IDs
- position_ids (Optional[paddle.Tensor]): Position indices
- token_type_ids (Optional[paddle.Tensor]): Token type IDs forwarded to the decoder layers
- attention_mask (Optional[paddle.Tensor]): Attention mask
- attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length attention indices
- inputs_embeds (Optional[paddle.Tensor]): Precomputed embeddings
- use_cache (Optional[bool]): Whether to cache key/value states
- past_key_values (Optional[Tuple[Tuple[paddle.Tensor]]]): Cached key/value states
- output_attentions (Optional[bool]): Whether to output attention weights
- output_hidden_states (Optional[bool]): Whether to output all hidden states
- return_dict (Optional[bool]): Whether to return dict or tuple
- Returns:
- Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
- Various outputs depending on configuration, including:
- - last_hidden_state: Final layer hidden states
- - past_key_values: Cached key/value states if use_cache=True
- - hidden_states: All hidden states if output_hidden_states=True
- - attentions: Attention weights if output_attentions=True
- """
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time"
- )
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError(
- "You have to specify either input_ids or inputs_embeds"
- )
- if batch_size != 1:
- raise NotImplementedError("Only batch_size == 1 is currently supported")
- layers = self.layers[: self.config.num_hidden_layers]
- if past_key_values is None:
- past_key_values = tuple([None] * len(layers))
- kv_seq_len = 0
- else:
- kv_seq_len = past_key_values[0][0].shape[1]
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- inputs_embeds = inputs_embeds.astype(self.embed_tokens.weight.dtype)
- if self.config.sequence_parallel:
- inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
- inputs_embeds = ScatterOp.apply(inputs_embeds)
- hidden_states = inputs_embeds
- if position_ids is None or position_ids.dim() == 2:
- raise NotImplementedError("3-D position_ids must be provided; 2-D positions are not supported")
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- if attention_mask is None:
- raise NotImplementedError("An explicit attention_mask must be provided")
- causal_mask = self._update_causal_mask(
- attention_mask.astype("int64"),
- inputs_embeds,
- past_key_values,
- output_attentions,
- )
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
- for idx, decoder_layer in enumerate(layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- past_key_value = (
- past_key_values[idx] if past_key_values is not None else None
- )
- has_gradient = not hidden_states.stop_gradient
- if (
- self.config.recompute
- and self.config.recompute_granularity == "full"
- and has_gradient
- ):
- layer_outputs = self.recompute_training(
- decoder_layer,
- hidden_states,
- position_embeddings,
- causal_mask,
- attn_mask_start_row_indices,
- position_ids,
- token_type_ids,
- output_attentions,
- past_key_value,
- use_cache,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- position_embeddings,
- causal_mask,
- attn_mask_start_row_indices,
- position_ids,
- token_type_ids,
- output_attentions,
- past_key_value,
- use_cache,
- )
- if isinstance(layer_outputs, (tuple, list)):
- hidden_states = layer_outputs[0]
- else:
- hidden_states = layer_outputs
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
- hidden_states = self.norm(hidden_states)
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(
- v
- for v in [
- hidden_states,
- next_cache,
- all_hidden_states,
- all_self_attns,
- ]
- if v is not None
- )
- return BaseModelOutputWithPastAndCrossAttentions(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=None,
- )
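- # --- Hedged usage sketch (added; shapes and layout are assumptions) ---
- # The forward above currently requires batch_size == 1, an explicit attention_mask and
- # 3-D position_ids (2-D positions raise NotImplementedError). The 3-D layout is assumed
- # here to be one leading row per rotary axis, in the style of multimodal RoPE; verify
- # against KeyeRotaryEmbedding before relying on it.
- def _text_backbone_inputs_sketch(input_ids):
-     seq_len = input_ids.shape[1]                                  # input_ids: [1, seq_len]
-     attention_mask = paddle.ones([1, seq_len], dtype="int64")
-     position_ids = paddle.arange(seq_len).reshape([1, 1, seq_len]).expand([3, 1, seq_len])
-     return {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}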
- def _update_causal_mask(
- self,
- attention_mask: paddle.Tensor,
- input_tensor: paddle.Tensor,
- past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]],
- output_attentions: bool = False,
- ):
- past_seen_tokens = (
- past_key_values[0][0].shape[1]
- if past_key_values is not None and past_key_values[0] is not None
- else 0
- )
- dtype = input_tensor.dtype
- min_dtype = paddle.finfo(dtype).min
- sequence_length = input_tensor.shape[1]
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, paddle.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
- cache_position = paddle.arange(
- past_seen_tokens, past_seen_tokens + sequence_length
- )
- # If the provided attention_mask is 2-D, expand it into a 4-D causal mask here.
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
- return causal_mask
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length: int,
- target_length: int,
- dtype,
- cache_position,
- batch_size: int,
- ):
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = paddle.finfo(dtype).min
- causal_mask = paddle.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype
- )
- diagonal_attend_mask = paddle.arange(
- target_length
- ) > cache_position.reshape((-1, 1))
- diagonal_attend_mask = diagonal_attend_mask.astype(causal_mask.dtype)
- causal_mask *= diagonal_attend_mask
- causal_mask = causal_mask[None, None, :, :].expand((batch_size, 1, -1, -1))
- if attention_mask is not None:
- causal_mask = (
- causal_mask.clone()
- ) # copy to contiguous memory for in-place edit
- if attention_mask.shape[-1] > target_length:
- attention_mask = attention_mask[:, :target_length]
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
- :, None, None, :
- ].astype(causal_mask.dtype)
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[
- :, :, :, :mask_length
- ].masked_fill(padding_mask, min_dtype)
- return causal_mask
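- # --- Hedged example (added): additive 4-D causal mask semantics ---
- # The helper above produces an additive mask: 0.0 where attention is allowed and
- # finfo(dtype).min where it is blocked (future positions or padded key positions).
- # Toy float32 re-implementation mirroring the logic above; pad_mask_2d is [batch, target_len].
- def _toy_causal_mask(seq_len, past_len, pad_mask_2d):
-     neg = paddle.finfo(paddle.float32).min
-     target_len = past_len + seq_len
-     cache_position = paddle.arange(past_len, past_len + seq_len)
-     mask = paddle.full((seq_len, target_len), neg, dtype="float32")
-     future = (paddle.arange(target_len) > cache_position.reshape((-1, 1))).astype("float32")
-     mask = mask * future                                     # 0 where attendable, neg in the future
-     mask = mask[None, None, :, :].expand((pad_mask_2d.shape[0], 1, -1, -1)).clone()
-     pad = (mask + pad_mask_2d[:, None, None, :].astype("float32")) == 0
-     return mask.masked_fill(pad, neg)                        # re-mask padded key positions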