_ernie.py
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Paddle Ernie model"""
  15. import contextlib
  16. import functools
  17. from functools import partial
  18. from typing import Optional, Tuple
  19. import numpy as np
  20. import paddle
  21. import paddle.distributed as dist
  22. import paddle.nn.functional as F
  23. from paddle import incubate, nn, tensor
  24. from paddle.autograd import PyLayer
  25. from paddle.distributed import fleet
  26. from paddle.distributed.fleet.layers.mpu import mp_ops
  27. from paddle.distributed.fleet.layers.mpu.mp_layers import (
  28. ColumnParallelLinear,
  29. RowParallelLinear,
  30. VocabParallelEmbedding,
  31. )
  32. from paddle.distributed.fleet.meta_parallel import (
  33. ParallelCrossEntropy,
  34. get_rng_state_tracker,
  35. )
  36. from paddle.distributed.fleet.utils import recompute
  37. from ......utils import logging
  38. from ....common.vlm.transformers import PretrainedModel
  39. from ....common.vlm.transformers.model_outputs import (
  40. BaseModelOutputWithPastAndCrossAttentions,
  41. )
  42. from ._config import PaddleOCRVLConfig
  43. from ._distributed import (
  44. AllGatherVarlenOp,
  45. ColumnParallelLinear,
  46. ColumnSequenceParallelLinear,
  47. GatherOp,
  48. RowParallelLinear,
  49. RowSequenceParallelLinear,
  50. RRColumnSequenceParallelLinear,
  51. RRRowSequenceParallelLinear,
  52. mark_as_sequence_parallel_parameter,
  53. parallel_matmul,
  54. sequence_parallel_sparse_mask_labels,
  55. )
  56. from ._fusion_ops import (
  57. Linear,
  58. fused_rms_norm_ext,
  59. fused_swiglu,
  60. fusion_flash_attention,
  61. )
  62. from ._sequence_parallel_utils import ScatterOp
  63. def calc_lm_head_logits(
  64. config, hidden_states, weight, bias, tensor_parallel_output=None, training=True
  65. ):
  66. """
  67. Calculate language model head logits with support for various parallelization strategies.
  68. This is the core function that computes the final output logits for a language model,
  69. handling sequence parallelism and tensor parallelism configurations.
  70. Args:
  71. config (PaddleOCRVLConfig): Model configuration.
  72. hidden_states (Tensor): Hidden states from the transformer layers
  73. weight (Tensor): Weight matrix for the language model head
  74. bias (Tensor): Bias vector for the language model head
  75. tensor_parallel_output (bool, optional): Override for tensor parallel output behavior.
  76. If None, uses config.tensor_parallel_output.
  77. Defaults to None.
  78. training (bool, optional): Whether in training mode. Defaults to True.
  79. Returns:
  80. Tensor: The computed logits for language modeling.
  81. """
  82. if config.sequence_parallel:
  83. if config.use_sparse_head_and_loss_fn:
  84. pass # Nothing needs to be done.
  85. else:
  86. hidden_states = GatherOp.apply(hidden_states)
  87. max_sequence_length = config.max_sequence_length
  88. hidden_states = hidden_states.reshape(
  89. [-1, max_sequence_length, hidden_states.shape[-1]]
  90. )
  91. if tensor_parallel_output is None:
  92. tensor_parallel_output = config.tensor_parallel_output
  93. logits = parallel_matmul(
  94. hidden_states,
  95. weight,
  96. bias=bias,
  97. transpose_y=config.tie_word_embeddings,
  98. tensor_parallel_degree=config.tensor_parallel_degree,
  99. tensor_parallel_output=tensor_parallel_output,
  100. fuse_linear=config.fuse_linear,
  101. training=training,
  102. )
  103. return logits
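# --- Illustrative sketch (added annotation, not part of the original module) ---
# With tensor_parallel_degree == 1 and sequence_parallel disabled, the call above
# reduces to a plain output projection, roughly:
#
#     logits = paddle.matmul(hidden_states, weight, transpose_y=config.tie_word_embeddings)
#     if bias is not None:
#         logits = logits + bias
#
# parallel_matmul additionally handles the vocabulary split across tensor-parallel
# ranks and, depending on tensor_parallel_output, the gathering of the sharded logits.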
  104. def subbatch(f, arg_idx, axis, bs, out_idx, use_recompute=False, same_arg_idx={}):
  105. """
  106. Converts a function to one that applies to subbatch of an input dimension.
  107. This is useful for processing large tensors in smaller chunks to reduce memory usage.
  108. Args:
  109. f (Callable): Original function to be converted to subbatch processing.
  110. arg_idx ([int]): Indices of the inputs to be subbatched.
  111. axis ([int]): Indices of the dimensions to be subbatched for each input.
  112. bs (int): Subbatch size (number of elements to process at once).
  113. out_idx (int): Output axis along which the per-chunk results are concatenated.
  114. use_recompute (bool, optional): Whether to use recomputation for memory savings. Defaults to False.
  115. same_arg_idx (dict, optional): Mapping of argument indices that share the same tensor.
  116. e.g. {1: 0} means args[1] == args[0], avoiding duplicate slicing.
  117. Returns:
  118. Callable: Converted function that processes inputs in subbatches.
  119. """
  120. @functools.wraps(f)
  121. def wrapper(*args, **kwargs):
  122. assert len(arg_idx) == len(
  123. axis
  124. ), "Number of batching args and number of batching dims should match."
  125. inps = [args[i] for i in arg_idx]
  126. axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
  127. assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
  128. inp_axis = {inp: d for inp, d in zip(inps, axis)}
  129. axis_width = axis_width[0]
  130. if axis_width < bs:
  131. return f(*args, **kwargs)
  132. outs = []
  133. for slice_at in np.arange(0, axis_width, bs):
  134. _args = []
  135. for i, inp in enumerate(args):
  136. if i in same_arg_idx:
  137. assert (
  138. i > same_arg_idx[i]
  139. ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
  140. _args.append(_args[same_arg_idx[i]])
  141. elif i in arg_idx:
  142. inp = inp.slice(
  143. [inp_axis[inp]],
  144. [slice_at],
  145. [min(inp.shape[inp_axis[inp]], slice_at + bs)],
  146. )
  147. _args.append(inp)
  148. else:
  149. _args.append(inp)
  150. if use_recompute:
  151. out = paddle.distributed.fleet.utils.recompute(f, *_args, **kwargs)
  152. else:
  153. out = f(*_args, **kwargs)
  154. outs.append(out)
  155. return paddle.concat(outs, out_idx)
  156. return wrapper
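# --- Illustrative usage sketch (added annotation; `score_fn` is a hypothetical helper) ---
# Wrap a memory-hungry function so it runs on 1024-row slices of its first two
# arguments (both sliced along dim 0) and the partial outputs are concatenated back
# along dim 0:
#
#     chunked = subbatch(score_fn, arg_idx=[0, 1], axis=[0, 0], bs=1024, out_idx=0)
#     out = chunked(hidden_2d, labels_1d)   # same result as score_fn, lower peak memory
#
# ErniePretrainingCriterion.forward_impl below uses this pattern to chunk loss_impl
# when the flattened sequence exceeds config.get("loss_subbatch_seqlen", 32768).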
  157. def _rotate_half(x):
  158. """Rotates half the hidden dims of the input."""
  159. x1 = x[..., : x.shape[-1] // 2]
  160. x2 = x[..., x.shape[-1] // 2 :]
  161. return paddle.concat((-x2, x1), axis=-1)
  162. def _apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  163. mrope_section = mrope_section * 2
  164. cos = paddle.concat(
  165. [m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1
  166. ).unsqueeze(unsqueeze_dim)
  167. sin = paddle.concat(
  168. [m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1
  169. ).unsqueeze(unsqueeze_dim)
  170. q_embed = (q * cos) + (_rotate_half(q) * sin)
  171. k_embed = (k * cos) + (_rotate_half(k) * sin)
  172. return q_embed, k_embed
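# --- Added note (assumption about the multimodal RoPE layout) ---
# cos/sin are produced by KeyeRotaryEmbedding with a leading axis of size 3 holding the
# temporal / height / width position tables, and mrope_section lists how many rotary
# channel pairs each table owns (e.g. [16, 24, 24] for head_dim == 128). After doubling,
# the split/concat above interleaves the three tables across the head dimension:
# channels 0..15 use temporal positions, 16..39 height, 40..63 width, and the same
# pattern repeats for the second half of the head dimension.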
  173. class FusedDropoutImpl(nn.Layer):
  174. """
  175. Fused dropout implementation with residual connection support.
  176. This layer combines dropout and residual addition in a single operation for better performance,
  177. particularly on GPU devices. The dropout is conditionally applied based on the probability.
  178. Args:
  179. prob (float): Dropout probability (between 0 and 1)
  180. mode (str): Dropout mode, either 'upscale_in_train' or 'downscale_in_infer'
  181. Attributes:
  182. prob (float): Stores the dropout probability
  183. mode (str): Stores the dropout mode
  184. dropout (nn.Dropout): The actual dropout layer instance
  185. """
  186. def __init__(self, prob, mode):
  187. """
  188. Initialize the fused dropout layer.
  189. Args:
  190. prob (float): Dropout probability (0 means no dropout)
  191. mode (str): Dropout mode ('upscale_in_train' or 'downscale_in_infer')
  192. """
  193. super().__init__()
  194. self.prob = prob
  195. self.mode = mode
  196. self.dropout = nn.Dropout(p=prob, mode=mode)
  197. def forward(self, x, y):
  198. """
  199. Forward pass of the fused dropout layer.
  200. Args:
  201. x (Tensor): Input tensor to potentially apply dropout on
  202. y (Tensor): Residual tensor to add to the (possibly dropped out) x
  203. Returns:
  204. Tensor: Result of x (with optional dropout) + y
  205. """
  206. if self.prob > 0:
  207. x = self.dropout(x)
  208. output = x + y
  209. return output
  210. class RMSNorm(nn.Layer):
  211. """
  212. Root Mean Square Layer Normalization (RMSNorm) implementation.
  213. RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
  214. omitting the mean-centering operation. This provides computational efficiency while maintaining
  215. good performance.
  216. """
  217. def __init__(self, config):
  218. """
  219. Initialize RMSNorm layer.
  220. Args:
  221. config (PaddleOCRVLConfig): Model configuration.
  222. """
  223. super().__init__()
  224. self.hidden_size = config.hidden_size
  225. self.weight = paddle.create_parameter(
  226. shape=[self.hidden_size],
  227. dtype=paddle.get_default_dtype(),
  228. default_initializer=nn.initializer.Constant(1.0),
  229. )
  230. self.variance_epsilon = config.rms_norm_eps
  231. self.config = config
  232. if config.sequence_parallel:
  233. mark_as_sequence_parallel_parameter(self.weight)
  234. def forward(self, hidden_states):
  235. """
  236. Apply RMS normalization to input hidden states.
  237. Args:
  238. hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
  239. Returns:
  240. Tensor: Normalized output tensor of same shape as input
  241. Note:
  242. - Uses fused kernel if config.fuse_rms_norm is True for better performance
  243. - Otherwise computes RMSNorm manually:
  244. 1. Compute variance of features
  245. 2. Apply reciprocal square root normalization
  246. 3. Scale by learned weight parameter
  247. - Maintains original dtype for numerical stability during computation
  248. """
  249. if self.config.fuse_rms_norm:
  250. return fused_rms_norm_ext(
  251. hidden_states, self.weight, self.variance_epsilon
  252. )[0].astype(self.weight.dtype)
  253. with paddle.amp.auto_cast(False):
  254. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  255. hidden_states = (
  256. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  257. )
  258. return hidden_states.astype(self.weight.dtype) * self.weight
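# --- Added reference note ---
# The manual branch of RMSNorm.forward computes, per hidden vector x:
#
#     y = weight * x / sqrt(mean(x_i ** 2) + variance_epsilon)
#
# i.e. LayerNorm without mean subtraction or bias; the statistics are taken in float32
# and the result is cast back to the weight dtype.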
  259. class LayerNorm(nn.LayerNorm):
  260. """
  261. Layer Normalization (LayerNorm) implementation with optional optimizations.
  262. This extends PaddlePaddle's built-in LayerNorm with:
  263. 1. Sequence parallelism support
  264. 2. Fast fused kernel implementation option
  265. 3. Configurable epsilon value
  266. """
  267. def __init__(self, config):
  268. """
  269. Initialize LayerNorm with configuration.
  270. Args:
  271. config (PaddleOCRVLConfig): Model configuration contains normalization parameters and flags.
  272. """
  273. super().__init__(config.hidden_size, epsilon=config.rms_norm_eps)
  274. self.config = config
  275. if config.sequence_parallel:
  276. mark_as_sequence_parallel_parameter(self.weight)
  277. mark_as_sequence_parallel_parameter(self.bias)
  278. class KeyeRotaryEmbedding(nn.Layer):
  279. def __init__(self, config: PaddleOCRVLConfig, device=None):
  280. super().__init__()
  281. self.rope_kwargs = {}
  282. if config is None:
  283. raise NotImplementedError
  284. else:
  285. # BC: "rope_type" was originally "type"
  286. if config.rope_scaling is not None:
  287. self.rope_type = config.rope_scaling.get(
  288. "rope_type", config.rope_scaling.get("type")
  289. )
  290. else:
  291. self.rope_type = "default"
  292. # BC: "rope_type" was originally "type"
  293. if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
  294. self.rope_type = config.rope_scaling.get(
  295. "rope_type", config.rope_scaling.get("type")
  296. )
  297. else:
  298. self.rope_type = "default"
  299. self.config = config
  300. if self.rope_type == "default":
  301. dim = config.head_dim
  302. inv_freq = 1.0 / (
  303. config.rope_theta
  304. ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)
  305. )
  306. self.attention_scaling = 1.0
  307. else:
  308. raise ValueError(f"Unsupported rope type: {self.rope_type}")
  309. self.register_buffer("inv_freq", inv_freq, persistable=False)
  310. self.original_inv_freq = self.inv_freq
  311. @paddle.no_grad()
  312. def forward(self, x, position_ids):
  313. # Core RoPE block. In contrast to other models, Keye has different position ids for the grids
  314. # So we expand the inv_freq to shape (3, ...)
  315. inv_freq_expanded = (
  316. self.inv_freq[None, None, :, None]
  317. .cast("float32")
  318. .expand((3, position_ids.shape[1], -1, 1))
  319. )
  320. position_ids_expanded = position_ids[:, :, None, :].cast(
  321. "float32"
  322. ) # shape (3, bs, 1, positions)
  323. with paddle.amp.auto_cast(enable=False):
  324. freqs = (
  325. inv_freq_expanded.cast("float32")
  326. @ position_ids_expanded.cast("float32")
  327. ).transpose((0, 1, 3, 2))
  328. emb = paddle.concat((freqs, freqs), axis=-1)
  329. cos = emb.cos()
  330. sin = emb.sin()
  331. # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
  332. cos = cos * self.attention_scaling
  333. sin = sin * self.attention_scaling
  334. return cos.astype(x.dtype), sin.astype(x.dtype)
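# --- Added shape note (sketch) ---
# position_ids is expected with shape [3, batch, seq] (temporal, height, width indices),
# so the matmul above yields cos/sin of shape [3, batch, seq, head_dim]; the three tables
# are later recombined per mrope_section in _apply_multimodal_rotary_pos_emb.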
  335. class Ernie4_5MLP(nn.Layer):
  336. """
  337. Ernie4_5MLP - Gated Multi-Layer Perceptron module used in Ernie model.
  338. """
  339. def __init__(self, config, layer_idx=0):
  340. """
  341. Initialize the MLP module with configuration options.
  342. Args:
  343. config (PaddleOCRVLConfig): Model configurations.
  344. layer_idx (int): Index of current layer (default: 0)
  345. """
  346. super().__init__()
  347. self.config = config
  348. self.hidden_size = config.hidden_size
  349. self.intermediate_size = config.intermediate_size
  350. if config.tensor_parallel_degree > 1:
  351. ColumnLN = (
  352. ColumnSequenceParallelLinear
  353. if config.sequence_parallel
  354. else ColumnParallelLinear
  355. )
  356. RowLN = (
  357. RowSequenceParallelLinear
  358. if config.sequence_parallel
  359. else RowParallelLinear
  360. )
  361. column_ln_configs = {}
  362. if (
  363. config.recompute
  364. and config.sequence_parallel
  365. and config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False)
  366. ):
  367. ColumnLN = RRColumnSequenceParallelLinear
  368. column_ln_configs = {"use_rr": True}
  369. self.up_gate_proj = ColumnLN(
  370. self.hidden_size,
  371. self.intermediate_size * 2,
  372. gather_output=False,
  373. has_bias=config.use_bias,
  374. fuse_matmul_bias=config.fuse_linear,
  375. **column_ln_configs,
  376. )
  377. else:
  378. LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
  379. self.up_gate_proj = LinearFN(
  380. self.hidden_size, self.intermediate_size * 2, bias_attr=config.use_bias
  381. )
  382. if config.tensor_parallel_degree > 1:
  383. row_ln_configs = {}
  384. if (
  385. config.recompute
  386. and config.sequence_parallel
  387. and config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False)
  388. ):
  389. RowLN = RRRowSequenceParallelLinear
  390. row_ln_configs = {"use_rr": True}
  391. self.down_proj = RowLN(
  392. self.intermediate_size,
  393. self.hidden_size,
  394. input_is_parallel=True,
  395. has_bias=config.use_bias,
  396. fuse_matmul_bias=config.fuse_linear,
  397. **row_ln_configs,
  398. )
  399. else:
  400. LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
  401. self.down_proj = LinearFN(
  402. self.intermediate_size, self.hidden_size, bias_attr=config.use_bias
  403. )
  404. self.fuse_swiglu = config.fuse_swiglu
  405. if self.fuse_swiglu:
  406. assert fused_swiglu is not None, "fused_swiglu operator is not found."
  407. def forward(self, x):
  408. """
  409. Forward pass through the MLP module.
  410. Args:
  411. x (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
  412. Returns:
  413. Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
  414. Note:
  415. Implements SwiGLU activation: swish(Wx) * (Vx) where W and V are
  416. the first and second halves of up_gate_proj output respectively.
  417. """
  418. if self.fuse_swiglu:
  419. x = self.up_gate_proj(x)
  420. x = fused_swiglu(x)
  421. else:
  422. gate, x = self.up_gate_proj(x).chunk(2, axis=-1)
  423. x = F.silu(gate) * x
  424. return self.down_proj(x)
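# --- Illustrative sketch (added annotation, not part of the original module) ---
# The unfused branch above is a standard SwiGLU MLP. With up_gate_proj holding the
# stacked [gate | up] projection:
#
#     gate, up = chunk(x @ W_up_gate, 2, axis=-1)   # each of width intermediate_size
#     y = (silu(gate) * up) @ W_down                # down_proj back to hidden_size
#
# When fuse_swiglu is set, fused_swiglu is expected to perform the chunk + silu + multiply
# in a single kernel.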
  425. class Ernie4_5Attention(nn.Layer):
  426. """Multi-headed attention from 'Attention Is All You Need' paper"""
  427. def __init__(self, config, layer_idx=0):
  428. """Initialize the attention layer.
  429. Args:
  430. config (PaddleOCRVLConfig): Model configuration.
  431. layer_idx (int, optional): Index in transformer stack. Defaults to 0.
  432. """
  433. super().__init__()
  434. self.layer_idx = layer_idx
  435. self.hidden_size = config.hidden_size
  436. self.num_heads = config.num_attention_heads
  437. self.num_key_value_heads = config.num_key_value_heads
  438. if getattr(config, "head_dim", None) is None:
  439. self.head_dim = self.hidden_size // self.num_heads
  440. else:
  441. self.head_dim = config.head_dim
  442. self.is_gqa = (
  443. config.num_key_value_heads is not None
  444. and config.num_key_value_heads != self.num_heads
  445. )
  446. self.rope_scaling = config.rope_scaling
  447. self.freq_allocation = config.get("freq_allocation", 0)
  448. if config.tensor_parallel_degree > 1:
  449. assert (
  450. self.num_heads % config.tensor_parallel_degree == 0
  451. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  452. self.num_heads = self.num_heads // config.tensor_parallel_degree
  453. if self.is_gqa:
  454. assert (
  455. self.num_key_value_heads % config.tensor_parallel_degree == 0
  456. ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  457. self.num_key_value_heads = (
  458. self.num_key_value_heads // config.tensor_parallel_degree
  459. )
  460. if self.is_gqa:
  461. logging.info(
  462. f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}"
  463. )
  464. assert (
  465. self.num_heads % self.num_key_value_heads == 0
  466. ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
  467. if getattr(config, "head_dim", None) is None:
  468. kv_hidden_size = (
  469. self.hidden_size // self.num_heads * self.num_key_value_heads
  470. )
  471. else:
  472. kv_hidden_size = self.head_dim * config.num_key_value_heads
  473. q_hidden_size = self.head_dim * config.num_attention_heads
  474. else:
  475. q_hidden_size = kv_hidden_size = self.head_dim * config.num_attention_heads
  476. if config.tensor_parallel_degree > 1:
  477. column_ln_configs = {}
  478. ColumnLN = (
  479. ColumnSequenceParallelLinear
  480. if config.sequence_parallel
  481. else ColumnParallelLinear
  482. )
  483. RowLN = (
  484. RowSequenceParallelLinear
  485. if config.sequence_parallel
  486. else RowParallelLinear
  487. )
  488. if (
  489. config.recompute
  490. and config.sequence_parallel
  491. and config.skip_recompute_ops[layer_idx].get(
  492. "attention_column_ln", False
  493. )
  494. ):
  495. ColumnLN = RRColumnSequenceParallelLinear
  496. column_ln_configs = {"use_rr": True}
  497. if getattr(config, "head_dim", None) is None:
  498. qkv_hidden_size = (
  499. self.hidden_size * 3
  500. if not self.is_gqa
  501. else self.hidden_size + kv_hidden_size * 2
  502. )
  503. else:
  504. qkv_hidden_size = q_hidden_size + kv_hidden_size * 2
  505. self.qkv_proj = ColumnLN(
  506. self.hidden_size,
  507. qkv_hidden_size,
  508. has_bias=config.use_bias,
  509. gather_output=False,
  510. fuse_matmul_bias=config.fuse_linear,
  511. **column_ln_configs,
  512. )
  513. else:
  514. LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
  515. if getattr(config, "head_dim", None) is None:
  516. qkv_hidden_size = (
  517. self.hidden_size * 3
  518. if not self.is_gqa
  519. else self.hidden_size + kv_hidden_size * 2
  520. )
  521. else:
  522. qkv_hidden_size = q_hidden_size + kv_hidden_size * 2
  523. self.qkv_proj = LinearFN(
  524. self.hidden_size,
  525. qkv_hidden_size,
  526. bias_attr=config.use_bias,
  527. )
  528. if config.tensor_parallel_degree > 1:
  529. row_ln_configs = {}
  530. if (
  531. config.recompute
  532. and config.sequence_parallel
  533. and config.skip_recompute_ops[layer_idx].get("attention_row_ln", False)
  534. ):
  535. RowLN = RRRowSequenceParallelLinear
  536. row_ln_configs = {"use_rr": True}
  537. self.o_proj = RowLN(
  538. (
  539. self.hidden_size
  540. if getattr(config, "head_dim", None) is None
  541. else q_hidden_size
  542. ),
  543. self.hidden_size,
  544. has_bias=config.use_bias,
  545. input_is_parallel=True,
  546. fuse_matmul_bias=config.fuse_linear,
  547. **row_ln_configs,
  548. )
  549. else:
  550. LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else Linear
  551. self.o_proj = LinearFN(
  552. (
  553. self.hidden_size
  554. if getattr(config, "head_dim", None) is None
  555. else q_hidden_size
  556. ),
  557. self.hidden_size,
  558. bias_attr=config.use_bias,
  559. )
  560. self.config = config
  561. self._rr_flash_attn = None
  562. if config.recompute and config.skip_recompute_ops[layer_idx].get(
  563. "flash_attn", False
  564. ):
  565. # TODO
  566. raise NotImplementedError
  567. self.set_attn_func()
  568. def set_attn_func(self):
  569. """Configure attention function based on settings.
  570. Selects between flash/core attention.
  571. """
  572. config = self.config
  573. if config.use_flash_attention:
  574. self.attn_func = self._flash_attention_wrapper
  575. else:
  576. self.attn_func = self.core_attn
  577. if config.cachekv_quant:
  578. # TODO: Support `cachekv_quant`
  579. raise NotImplementedError
  580. def forward(
  581. self,
  582. hidden_states,
  583. position_embeddings,
  584. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  585. attention_mask: Optional[paddle.Tensor] = None,
  586. attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
  587. position_ids: Optional[Tuple[paddle.Tensor]] = None,
  588. output_attentions: bool = False,
  589. use_cache: bool = False,
  590. token_type_ids: Optional[Tuple[paddle.Tensor]] = None, # MLLM
  591. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  592. """Compute attention outputs.
  593. Args:
  594. hidden_states (paddle.Tensor): Input tensor [bsz, seq_len, hidden_size]
  595. position_embeddings (paddle.Tensor): Position embeddings
  596. past_key_value (Optional[Tuple[paddle.Tensor, paddle.Tensor]]): Cached key/value states
  597. attention_mask (Optional[paddle.Tensor]): Attention mask tensor
  598. attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length attention indices
  599. position_ids (Optional[paddle.Tensor]): Position indices for RoPE
  600. output_attentions (bool): Return attention weights if True
  601. use_cache (bool): Cache key/value states if True
  602. Returns:
  603. Tuple containing:
  604. - attention_output: [bsz, seq_len, hidden_size]
  605. - attention_weights: Optional attention probabilities
  606. - updated_key_value_cache: Optional updated cache
  607. """
  608. if token_type_ids is not None:
  609. token_type_ids = token_type_ids[:, :-1]
  610. if self.config.sequence_parallel:
  611. if token_type_ids is not None:
  612. token_type_ids = token_type_ids.reshape([-1])
  613. token_type_ids = ScatterOp.apply(token_type_ids)
  614. token_type_ids.stop_gradient = True
  615. max_sequence_length = self.config.max_sequence_length
  616. bsz = (
  617. hidden_states.shape[0]
  618. * self.config.tensor_parallel_degree
  619. // max_sequence_length
  620. )
  621. q_len = max_sequence_length
  622. else:
  623. bsz, q_len, _ = hidden_states.shape
  624. query_states = key_states = value_states = mix_layer = None
  625. mix_layer = self.qkv_proj(hidden_states)
  626. if self.is_gqa:
  627. query_states, key_states, value_states = paddle.split(
  628. mix_layer.reshape([bsz, q_len, -1, self.head_dim]),
  629. [self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
  630. axis=2,
  631. )
  632. mix_layer = None
  633. else:
  634. mix_layer = mix_layer.reshape(
  635. [bsz, q_len, self.num_heads, 3 * self.head_dim]
  636. )
  637. if mix_layer is not None:
  638. has_gradient = not mix_layer.stop_gradient
  639. else:
  640. has_gradient = not (
  641. query_states.stop_gradient
  642. and key_states.stop_gradient
  643. and value_states.stop_gradient
  644. )
  645. if (
  646. self.config.recompute
  647. and self.config.recompute_granularity == "core_attn"
  648. and has_gradient
  649. ):
  650. assert past_key_value is None, "do not use kv cache in recompute"
  651. assert not use_cache
  652. attn_output, attn_weights, past_key_value = recompute(
  653. self.rope_attn,
  654. mix_layer,
  655. query_states,
  656. key_states,
  657. value_states,
  658. position_embeddings,
  659. attention_mask,
  660. position_ids,
  661. output_attentions,
  662. past_key_value,
  663. use_cache,
  664. attn_mask_start_row_indices,
  665. use_reentrant=self.config.recompute_use_reentrant,
  666. )
  667. else:
  668. attn_output, attn_weights, past_key_value = self.rope_attn(
  669. mix_layer=mix_layer,
  670. query_states=query_states,
  671. key_states=key_states,
  672. value_states=value_states,
  673. position_embeddings=position_embeddings,
  674. attention_mask=attention_mask,
  675. position_ids=position_ids,
  676. output_attentions=output_attentions,
  677. past_key_value=past_key_value,
  678. use_cache=use_cache,
  679. attn_mask_start_row_indices=attn_mask_start_row_indices,
  680. )
  681. if self.config.sequence_parallel:
  682. attn_output = attn_output.reshape([-1, attn_output.shape[-1]])
  683. attn_output = self.o_proj(attn_output)
  684. if not output_attentions:
  685. attn_weights = None
  686. return attn_output, attn_weights, past_key_value
  687. def _flash_attention_wrapper(
  688. self,
  689. q,
  690. k,
  691. v,
  692. attention_mask=None,
  693. attn_mask_start_row_indices=None,
  694. seq_length=None,
  695. ):
  696. """Optimized flash attention implementation.
  697. Args:
  698. q (paddle.Tensor): Query tensor
  699. k (paddle.Tensor): Key tensor
  700. v (paddle.Tensor): Value tensor
  701. attention_mask (Optional[paddle.Tensor]): Attention mask
  702. attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length indices
  703. seq_length (Optional[int]): Sequence length
  704. Returns:
  705. paddle.Tensor: Attention output tensor
  706. """
  707. return fusion_flash_attention(
  708. q,
  709. k,
  710. v,
  711. self.training,
  712. self.config.attention_probs_dropout_prob,
  713. self.config.use_sparse_flash_attn,
  714. attention_mask,
  715. attn_mask_start_row_indices,
  716. seq_length,
  717. self.config.use_var_len_flash_attn,
  718. self._rr_flash_attn if self.training else None,
  719. )
  720. def core_attn(
  721. self,
  722. q,
  723. k,
  724. v,
  725. attention_mask=None,
  726. attn_mask_start_row_indices=None,
  727. seq_length=None,
  728. ):
  729. """Standard self-attention implementation.
  730. Args:
  731. q (paddle.Tensor): Query tensor
  732. k (paddle.Tensor): Key tensor
  733. v (paddle.Tensor): Value tensor
  734. attention_mask (Optional[paddle.Tensor]): Attention mask
  735. attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length indices
  736. seq_length (Optional[int]): Sequence length
  737. Returns:
  738. Tuple[paddle.Tensor, paddle.Tensor]: Attention output and weights
  739. """
  740. perm = [
  741. 0,
  742. 2,
  743. 1,
  744. 3,
  745. ] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3]
  746. origin_dtype = q.dtype
  747. q = tensor.transpose(x=q, perm=perm)
  748. k = tensor.transpose(x=k, perm=perm)
  749. v = tensor.transpose(x=v, perm=perm)
  750. replicate = self.config.num_attention_heads // self.config.num_key_value_heads
  751. k = paddle.repeat_interleave(k, replicate, axis=1)
  752. v = paddle.repeat_interleave(v, replicate, axis=1)
  753. scale_qk_coeff = self.config.scale_qk_coeff * self.head_dim**0.5
  754. product = paddle.matmul(x=q.scale(1.0 / scale_qk_coeff), y=k, transpose_y=True)
  755. product = product.cast(paddle.float32)
  756. if self.config.scale_qk_coeff != 1.0:
  757. product = product.scale(self.config.scale_qk_coeff)
  758. if attention_mask is not None:
  759. attention_mask = attention_mask.cast(paddle.float32)
  760. if self.config.fuse_softmax_mask:
  761. weights = incubate.softmax_mask_fuse(product, attention_mask)
  762. else:
  763. product = product + attention_mask
  764. weights = F.softmax(product)
  765. else:
  766. weights = incubate.softmax_mask_fuse_upper_triangle(product)
  767. weights = weights.cast(origin_dtype)
  768. if self.config.attention_probs_dropout_prob:
  769. with get_rng_state_tracker().rng_state("local_seed"):
  770. weights = F.dropout(
  771. weights,
  772. self.config.attention_probs_dropout_prob,
  773. training=self.training,
  774. mode="upscale_in_train",
  775. )
  776. out = paddle.matmul(weights, v)
  777. # combine heads
  778. out = tensor.transpose(out, perm=[0, 2, 1, 3])
  779. # If sequence_parallel is true, out shape is [s, b, h] after reshape
  780. # else out shape is [b, s, h]
  781. out = tensor.reshape(x=out, shape=[0, 0, -1])
  782. return out, weights
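# --- Added note on GQA handling (sketch) ---
# core_attn expands the key/value heads with repeat_interleave so that plain matmul
# attention can be used; e.g. with num_attention_heads == 32 and num_key_value_heads == 8
# (hypothetical values) each kv head is tiled 4 times:
#
#     k: [bsz, 8, seq, head_dim]  ->  [bsz, 32, seq, head_dim]
#
# The flash path (_flash_attention_wrapper) instead passes the un-expanded tensors
# straight to fusion_flash_attention.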
  783. def rope_attn(
  784. self,
  785. mix_layer,
  786. query_states,
  787. key_states,
  788. value_states,
  789. position_embeddings,
  790. attention_mask,
  791. position_ids,
  792. output_attentions=False,
  793. past_key_value=None,
  794. use_cache=False,
  795. attn_mask_start_row_indices=None,
  796. ):
  797. if mix_layer is not None:
  798. query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1)
  799. query_states_dtype = query_states.dtype
  800. kv_seq_len = position_ids.max() + 1
  801. offset = 0
  802. if past_key_value is not None:
  803. # LLM
  804. offset = past_key_value[0].shape[-3]
  805. kv_seq_len += offset
  806. query_states = query_states.astype(query_states_dtype)
  807. key_states = key_states.astype(query_states_dtype)
  808. if position_ids.dim() == 3 and position_ids.shape[0] > 1:
  809. position_ids = position_ids[0:1]
  810. cos, sin = position_embeddings
  811. query_states, key_states = _apply_multimodal_rotary_pos_emb(
  812. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], 2
  813. )
  814. if past_key_value is not None:
  815. # reuse k, v, self_attention
  816. key_states = paddle.concat([past_key_value[0], key_states], axis=1)
  817. value_states = paddle.concat([past_key_value[1], value_states], axis=1)
  818. # NOTE(for generation): use list instead of tuple to store the cache
  819. # tensors, so that we can clear the cache tensors for memory efficiency.
  820. past_key_value = [key_states, value_states] if use_cache else None
  821. seq_length = query_states.shape[1]
  822. attn_output, attn_weights = self.attn_func(
  823. query_states,
  824. key_states,
  825. value_states,
  826. attention_mask,
  827. attn_mask_start_row_indices,
  828. seq_length,
  829. )
  830. return attn_output, attn_weights, past_key_value
  831. class FusedHeadParallelCrossEntropy(PyLayer):
  832. """Fused parallel cross-entropy loss computation for large sequence lengths.
  833. Combines head projection and loss computation with optimized memory usage for long sequences,
  834. supporting tensor parallel training.
  835. """
  836. @staticmethod
  837. def forward(
  838. ctx,
  839. hidden_states,
  840. weight,
  841. bias,
  842. labels,
  843. tensor_parallel_degree,
  844. mp_group=None,
  845. ignore_index=-100,
  846. seq_chunk_size=8192,
  847. transpose_y=False,
  848. fuse_linear=False,
  849. training=True,
  850. ):
  851. """Forward pass for parallel cross-entropy computation.
  852. Args:
  853. ctx: Context object for saving tensors between forward/backward
  854. hidden_states (paddle.Tensor): Input tensor of shape [batch_size*seq_len, hidden_size]
  855. weight (paddle.Tensor): Weight matrix for projection
  856. bias (Optional[paddle.Tensor]): Optional bias vector
  857. labels (paddle.Tensor): Target labels tensor of shape [batch_size*seq_len]
  858. tensor_parallel_degree (int): Degree of tensor parallelism
  859. mp_group (Optional[dist.Group]): Model parallel group. Defaults to None (auto-detect)
  860. ignore_index (int): Index to ignore in loss computation. Defaults to -100
  861. seq_chunk_size (int): Chunk size for processing long sequences. Defaults to 8192
  862. transpose_y (bool): Whether to transpose weight matrix. Defaults to False
  863. fuse_linear (bool): Whether to use fused linear ops. Defaults to False
  864. training (bool): Whether in training mode. Defaults to True
  865. Returns:
  866. Tuple[paddle.Tensor, paddle.Tensor]:
  867. - loss: Computed loss tensor
  868. - gathered_labels: Concatenated labels from all parallel groups
  869. """
  870. ctx.tensor_parallel_degree = tensor_parallel_degree
  871. ctx.ignore_index = ignore_index
  872. ctx.seq_chunk_size = seq_chunk_size
  873. ctx.transpose_y = transpose_y
  874. ctx.fuse_linear = fuse_linear
  875. ctx.training = training
  876. ctx.hidden_states_shape = hidden_states.shape
  877. ctx.mp_group = (
  878. fleet.get_hybrid_communicate_group().get_model_parallel_group()
  879. if mp_group is None
  880. else mp_group
  881. )
  882. ctx.rank = ctx.mp_group.rank
  883. ctx.world_size = ctx.mp_group.nranks
  884. loss_all = []
  885. labels_all = []
  886. with paddle.no_grad():
  887. labels = labels.reshape_([-1])
  888. hidden_states = hidden_states.reshape_([-1, hidden_states.shape[-1]])
  889. num_tokens_per_rank = []
  890. dist.stream.all_gather(
  891. num_tokens_per_rank,
  892. paddle.to_tensor(hidden_states.shape[0], dtype=paddle.int32),
  893. group=ctx.mp_group,
  894. )
  895. ctx.num_tokens_per_rank = num_tokens_per_rank
  896. for idx in range(ctx.world_size):
  897. if idx == ctx.rank:
  898. hidden_states_recv = hidden_states
  899. labels_recv = labels
  900. else:
  901. hidden_states_recv = paddle.empty(
  902. [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]],
  903. dtype=hidden_states.dtype,
  904. )
  905. labels_recv = paddle.empty(
  906. [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype
  907. )
  908. dist.stream.broadcast(
  909. hidden_states_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
  910. )
  911. dist.stream.broadcast(
  912. labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
  913. )
  914. seq_len = hidden_states_recv.shape[0]
  915. num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size
  916. loss_chunk = []
  917. for chunk_idx in range(num_chunk):
  918. start = chunk_idx * ctx.seq_chunk_size
  919. end = min(start + ctx.seq_chunk_size, seq_len)
  920. hidden_states_chunk = hidden_states_recv._slice(start, end)
  921. labels_chunk = labels_recv._slice(start, end)
  922. logits = parallel_matmul(
  923. hidden_states_chunk,
  924. weight,
  925. bias=bias,
  926. transpose_y=ctx.transpose_y,
  927. tensor_parallel_degree=ctx.tensor_parallel_degree,
  928. tensor_parallel_output=True,
  929. fuse_linear=ctx.fuse_linear,
  930. training=ctx.training,
  931. )
  932. with paddle.amp.auto_cast(False):
  933. loss = mp_ops._c_softmax_with_cross_entropy(
  934. logits.cast("float32"),
  935. labels_chunk.unsqueeze(-1),
  936. group=ctx.mp_group,
  937. ignore_index=ctx.ignore_index,
  938. )
  939. loss_chunk.append(loss)
  940. loss_all.append(paddle.concat(loss_chunk, axis=0))
  941. labels_all.append(labels_recv)
  942. ctx.loss_concat_sections = [loss.shape[0] for loss in loss_all]
  943. loss_all = paddle.concat(loss_all, axis=0)
  944. labels_all = paddle.concat(labels_all, axis=0)
  945. tensor_inputs = [hidden_states, weight, bias, labels]
  946. ctx.save_for_backward(*tensor_inputs)
  947. return loss_all, labels_all
  948. @staticmethod
  949. def backward(ctx, loss_all_grad, labels_all_grad):
  950. """Backward pass for parallel cross-entropy computation.
  951. Args:
  952. ctx: Context object with saved tensors from forward
  953. loss_all_grad (paddle.Tensor): Gradient of loss
  954. labels_all_grad (paddle.Tensor): Gradient of labels (unused)
  955. Returns:
  956. Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[paddle.Tensor], None]:
  957. - hidden_states_grad: Gradient for input hidden states
  958. - weight_grad: Gradient for weight matrix (None if not trainable)
  959. - bias_grad: Gradient for bias vector (None if not trainable or not provided)
  960. - None: Placeholder for labels gradient
  961. """
  962. hidden_states, weight, bias, labels = ctx.saved_tensor()
  963. loss_all_grad_list = paddle.split(
  964. loss_all_grad, ctx.loss_concat_sections, axis=0
  965. )
  966. def detach_variable(inp):
  967. if inp is None:
  968. return None
  969. x = inp.detach()
  970. x.stop_gradient = inp.stop_gradient
  971. return x
  972. if weight.stop_gradient is False:
  973. weight_main_grad = paddle.zeros(weight.shape, dtype=paddle.float32)
  974. else:
  975. weight_main_grad = None
  976. if bias is not None and bias.stop_gradient is False:
  977. bias_main_grad = paddle.zeros(bias.shape, dtype=paddle.float32)
  978. else:
  979. bias_main_grad = None
  980. hidden_states = detach_variable(hidden_states)
  981. weight = detach_variable(weight)
  982. bias = detach_variable(bias)
  983. labels = detach_variable(labels)
  984. with paddle.base.dygraph.guard():
  985. tracer = paddle.base.framework._dygraph_tracer()
  986. tracer._has_grad = True
  987. for idx in range(ctx.world_size):
  988. if idx == ctx.rank:
  989. hidden_states_recv = hidden_states
  990. labels_recv = labels
  991. else:
  992. hidden_states_recv = paddle.empty(
  993. [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]],
  994. dtype=hidden_states.dtype,
  995. )
  996. labels_recv = paddle.empty(
  997. [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype
  998. )
  999. dist.stream.broadcast(
  1000. hidden_states_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
  1001. )
  1002. dist.stream.broadcast(
  1003. labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group
  1004. )
  1005. hidden_states_recv.stop_gradient = False
  1006. seq_len = hidden_states_recv.shape[0]
  1007. num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size
  1008. for chunk_idx in range(num_chunk):
  1009. start = chunk_idx * ctx.seq_chunk_size
  1010. end = min(start + ctx.seq_chunk_size, seq_len)
  1011. hidden_states_chunk = hidden_states_recv.slice(
  1012. axes=[0], starts=[start], ends=[end]
  1013. )
  1014. labels_chunk = labels_recv._slice(start, end)
  1015. loss_grad_chunk = loss_all_grad_list[idx]._slice(start, end)
  1016. logits = parallel_matmul(
  1017. hidden_states_chunk,
  1018. weight,
  1019. bias=bias,
  1020. transpose_y=ctx.transpose_y,
  1021. tensor_parallel_degree=ctx.tensor_parallel_degree,
  1022. tensor_parallel_output=True,
  1023. fuse_linear=ctx.fuse_linear,
  1024. training=ctx.training,
  1025. )
  1026. with paddle.amp.auto_cast(False):
  1027. loss_chunk = mp_ops._c_softmax_with_cross_entropy(
  1028. logits.cast("float32"),
  1029. labels_chunk.unsqueeze(-1),
  1030. group=ctx.mp_group,
  1031. ignore_index=ctx.ignore_index,
  1032. )
  1033. with paddle.amp.auto_cast(enable=False):
  1034. paddle.autograd.backward(loss_chunk, loss_grad_chunk)
  1035. if weight_main_grad is not None:
  1036. weight_main_grad.add_(weight.grad.cast(paddle.float32))
  1037. weight.clear_gradient(True)
  1038. if bias_main_grad is not None:
  1039. bias_main_grad.add_(bias.grad.cast(paddle.float32))
  1040. bias.clear_gradient(True)
  1041. if idx == ctx.rank:
  1042. hidden_states_grad = hidden_states_recv.grad
  1043. hidden_states_grad = hidden_states_grad.reshape(
  1044. ctx.hidden_states_shape
  1045. )
  1046. if weight_main_grad is not None:
  1047. weight_main_grad = weight_main_grad.astype(weight.dtype)
  1048. if bias_main_grad is not None:
  1049. bias_main_grad = bias_main_grad.astype(bias.dtype)
  1050. return (
  1051. hidden_states_grad,
  1052. weight_main_grad,
  1053. bias_main_grad,
  1054. None,
  1055. )
  1056. class ErniePretrainingCriterion(paddle.nn.Layer):
  1057. """Criterion for ERNIE pretraining task."""
  1058. def __init__(self, config, return_tuple=True):
  1059. """Initialize the pretraining criterion.
  1060. Args:
  1061. config (PaddleOCRVLConfig): Model configuration.
  1062. return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
  1063. """
  1064. super(ErniePretrainingCriterion, self).__init__()
  1065. self.ignored_index = getattr(config, "ignored_index", -100)
  1066. self.config = config
  1067. self.return_tuple = return_tuple
  1068. self.enable_parallel_cross_entropy = (
  1069. config.tensor_parallel_degree > 1 and config.tensor_parallel_output
  1070. )
  1071. if (
  1072. self.enable_parallel_cross_entropy
  1073. ): # and False: # and lm_head is distributed
  1074. logging.info("using parallel cross entropy, take care")
  1075. self.loss_func = ParallelCrossEntropy()
  1076. else:
  1077. self.loss_func = paddle.nn.CrossEntropyLoss(
  1078. reduction="none",
  1079. )
  1080. self.token_balance_loss = config.token_balance_loss
  1081. def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):
  1082. """Compute the pretraining loss.
  1083. Args:
  1084. prediction_scores (Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]):
  1085. Either:
  1086. - Direct logits tensor [batch_size, seq_len, vocab_size]
  1087. - Tuple of (hidden_states, weight, bias) for sparse head computation
  1088. masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
  1089. loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens. Defaults to None.
  1090. Returns:
  1091. Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
  1092. - If return_tuple=False: Single loss tensor
  1093. - If return_tuple=True: Tuple of (normalized_loss, sum_loss)
  1094. """
  1095. if self.config.use_sparse_head_and_loss_fn:
  1096. hidden_states, outlinear_weight, outlinear_bias = prediction_scores[:3]
  1097. if self.config.sequence_parallel:
  1098. masked_lm_labels, sparse_label_idx = (
  1099. sequence_parallel_sparse_mask_labels(
  1100. masked_lm_labels, self.ignored_index
  1101. )
  1102. )
  1103. sparse_label_idx = sparse_label_idx.reshape([-1, 1])
  1104. hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0)
  1105. hidden_states = AllGatherVarlenOp.apply(hidden_states)
  1106. else:
  1107. masked_lm_labels = masked_lm_labels.flatten()
  1108. sparse_label_idx = paddle.nonzero(
  1109. masked_lm_labels != self.ignored_index
  1110. ).flatten()
  1111. masked_lm_labels = paddle.take_along_axis(
  1112. masked_lm_labels, sparse_label_idx, axis=0
  1113. )
  1114. hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
  1115. hidden_states = paddle.take_along_axis(
  1116. hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0
  1117. )
  1118. # `loss_mask` must be reset to None and re-calculated in ErnieBotPretrainingCriterion
  1119. # when use_sparse_head_and_loss_fn is enabled.
  1120. loss_mask = None
  1121. if self.config.use_recompute_loss_fn:
  1122. offload_kwargs = {}
  1123. if self.config.get("offload_lm_head", False):
  1124. offload_kwargs["offload_indices"] = [1]
  1125. res = recompute(
  1126. self.forward_impl_with_calc_logits,
  1127. masked_lm_labels,
  1128. loss_mask,
  1129. hidden_states,
  1130. outlinear_weight,
  1131. outlinear_bias,
  1132. **offload_kwargs,
  1133. )
  1134. else:
  1135. logits = calc_lm_head_logits(
  1136. self.config,
  1137. hidden_states,
  1138. outlinear_weight,
  1139. outlinear_bias,
  1140. training=self.training,
  1141. )
  1142. res = self.forward_impl(logits, masked_lm_labels, loss_mask)
  1143. elif self.config.use_recompute_loss_fn:
  1144. if self.config.use_fused_head_and_loss_fn:
  1145. res = self.forward_impl_with_fused_head_loss_fn(
  1146. masked_lm_labels, loss_mask, *prediction_scores
  1147. )
  1148. else:
  1149. assert isinstance(prediction_scores, tuple) and len(
  1150. prediction_scores
  1151. ) in [3, 4], prediction_scores
  1152. res = recompute(
  1153. self.forward_impl_with_calc_logits,
  1154. masked_lm_labels,
  1155. loss_mask,
  1156. *prediction_scores,
  1157. )
  1158. else:
  1159. res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
  1160. return res
  1161. def forward_impl_with_fused_head_loss_fn(
  1162. self,
  1163. masked_lm_labels,
  1164. loss_mask,
  1165. hidden_states,
  1166. outlinear_weight,
  1167. outlinear_bias,
  1168. ):
  1169. """Compute loss with fused head and parallel cross-entropy.
  1170. Args:
  1171. masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
  1172. loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens
  1173. hidden_states (paddle.Tensor): Hidden states from transformer [batch_size, seq_len, hidden_size]
  1174. outlinear_weight (paddle.Tensor): Weight matrix for output projection
  1175. outlinear_bias (Optional[paddle.Tensor]): Optional bias for output projection
  1176. Returns:
  1177. Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
  1178. Same return format as forward()
  1179. """
  1180. assert (
  1181. self.config.tensor_parallel_degree > 0
  1182. ), "use_fused_head_and_loss_fn require tensor_parallel_degree > 0"
  1183. masked_lm_loss, masked_lm_labels_all = FusedHeadParallelCrossEntropy.apply(
  1184. hidden_states,
  1185. outlinear_weight,
  1186. outlinear_bias,
  1187. masked_lm_labels,
  1188. self.config.tensor_parallel_degree,
  1189. ignore_index=self.ignored_index,
  1190. seq_chunk_size=self.config.get("loss_subbatch_seqlen", 32768),
  1191. transpose_y=self.config.tie_word_embeddings,
  1192. fuse_linear=self.config.fuse_linear,
  1193. training=self.training,
  1194. )
  1195. if loss_mask is None:
  1196. loss_mask = masked_lm_labels_all != self.ignored_index
  1197. if (~loss_mask).all(): # empty span
  1198. logging.warning(
  1199. f"encounter empty span when calculate loss, ignored_index={self.ignored_index}"
  1200. )
  1201. loss = paddle.mean(masked_lm_loss) * 0.0
  1202. loss_sum = masked_lm_loss.sum().detach()
  1203. else:
  1204. loss_mask = loss_mask.reshape([-1]).cast(paddle.float32)
  1205. # align per token position; aggregate in full precision (float32)
  1206. masked_lm_loss = paddle.sum(
  1207. masked_lm_loss.cast(paddle.float32).reshape([-1]) * loss_mask
  1208. )
  1209. loss = masked_lm_loss / loss_mask.sum()
  1210. if self.token_balance_loss:
  1211. _loss = masked_lm_loss / self.config.token_balance_seqlen
  1212. loss = _loss - _loss.detach() + loss.detach() # for loss alignment
  1213. loss_sum = masked_lm_loss.sum().detach()
  1214. if not self.return_tuple: # only used in pp
  1215. if self.training:
  1216. return loss
  1217. return loss_sum
  1218. return loss, loss_sum
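# --- Added note on the token_balance_loss trick (sketch) ---
# loss = _loss - _loss.detach() + loss.detach() keeps the *value* of `loss`
# (sum of per-token losses / number of valid tokens) while taking its *gradient* from
# `_loss` (the same sum divided by the fixed config.token_balance_seqlen), so each
# micro-batch contributes gradients on a comparable scale regardless of how many
# tokens were masked out.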

    def forward_impl_with_calc_logits(
        self,
        masked_lm_labels,
        loss_mask,
        hidden_states,
        outlinear_weight,
        outlinear_bias,
    ):
        """Compute logits then calculate loss.

        Args:
            Same as forward_impl_with_fused_head_loss_fn()

        Returns:
            Same return format as forward()
        """
        logits = calc_lm_head_logits(
            self.config,
            hidden_states,
            outlinear_weight,
            outlinear_bias,
            training=self.training,
        )
        return self.forward_impl(logits, masked_lm_labels, loss_mask)

    def loss_impl(self, prediction_scores, masked_lm_labels):
        """Core loss computation without reduction.

        Args:
            prediction_scores (paddle.Tensor): Logits tensor [batch_size, seq_len, vocab_size]
            masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]

        Returns:
            paddle.Tensor: Unreduced loss tensor
        """
        prediction_scores = prediction_scores.cast("float32")
        masked_lm_loss = self.loss_func(
            prediction_scores, masked_lm_labels.unsqueeze(-1)
        )
        return masked_lm_loss

    def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
        """Standard loss computation with reduction and masking.

        Args:
            prediction_scores (paddle.Tensor): Logits tensor [batch_size, seq_len, vocab_size]
            masked_lm_labels (paddle.Tensor): Target labels tensor [batch_size, seq_len]
            loss_mask (Optional[paddle.Tensor]): Optional mask for valid tokens

        Returns:
            Same return format as forward()
        """
        if self.enable_parallel_cross_entropy:
            assert prediction_scores.shape[-1] != self.config.vocab_size, (
                f"enable_parallel_cross_entropy, the vocab_size should be split:"
                f" {prediction_scores.shape[-1]}, {self.config.vocab_size}"
            )

        with paddle.amp.auto_cast(False):
            prediction_scores_dims = len(prediction_scores.shape)
            if prediction_scores_dims == 2 and prediction_scores.shape[
                0
            ] > self.config.get("loss_subbatch_seqlen", 32768):
                sb_loss_func = subbatch(
                    self.loss_impl,
                    [0, 1],
                    [0, 0],
                    self.config.get("loss_subbatch_seqlen", 32768),
                    0,
                )
                masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
            elif prediction_scores_dims == 3 and prediction_scores.shape[
                1
            ] > self.config.get("loss_subbatch_seqlen", 32768):
                sb_loss_func = subbatch(
                    self.loss_impl,
                    [0, 1],
                    [1, 1],
                    self.config.get("loss_subbatch_seqlen", 32768),
                    1,
                )
                masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
            else:
                masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)

            if loss_mask is None:
                loss_mask = masked_lm_labels != self.ignored_index
            # empty-span detection always uses the label-derived mask
            lossmask = masked_lm_labels != self.ignored_index
            if (~lossmask).all():  # empty span
                logging.warning(
                    f"encountered empty span when calculating loss, ignored_index={self.ignored_index}"
                )
                loss = paddle.mean(masked_lm_loss) * 0.0
                loss_sum = masked_lm_loss.sum().detach()
            else:
                loss_mask = loss_mask.reshape([-1]).cast(paddle.float32)
                # align token-by-token, aggregate in full precision
                masked_lm_loss = paddle.sum(
                    masked_lm_loss.cast(paddle.float32).reshape([-1]) * loss_mask
                )
                loss = masked_lm_loss / loss_mask.sum()
                if self.token_balance_loss:
                    _loss = masked_lm_loss / self.config.token_balance_seqlen
                    # straight-through: report the masked mean, take gradients from _loss
                    loss = _loss - _loss.detach() + loss.detach()
                loss_sum = masked_lm_loss.sum().detach()

        if not self.return_tuple:  # only used in pp
            if self.training:
                return loss
            return loss_sum
        return loss, loss_sum
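

# The sketch below is illustrative only and is not used by the model. It walks
# through the reduction performed in forward_impl above: per-token cross-entropy
# (optionally evaluated in sequence chunks, which is the idea behind the
# `subbatch` wrapper, whose exact signature differs), a masked mean over valid
# tokens in float32, and the straight-through "token balance" trick. All tensors
# and sizes here are hypothetical stand-ins.
def _demo_masked_mean_loss_sketch():
    import paddle

    vocab, seq_chunk, token_balance_seqlen = 32, 8, 64  # hypothetical values
    logits = paddle.randn([1, 20, vocab])
    labels = paddle.randint(0, vocab, shape=[1, 20])
    loss_mask = paddle.ones([1, 20])

    # Per-token loss, computed chunk by chunk along the sequence so the full
    # float32 [seq, vocab] logits never have to be materialized at once.
    loss_fn = paddle.nn.CrossEntropyLoss(reduction="none")
    chunks = []
    for start in range(0, logits.shape[1], seq_chunk):
        piece = logits[:, start:start + seq_chunk, :].cast("float32")
        piece_labels = labels[:, start:start + seq_chunk]
        chunks.append(loss_fn(piece, piece_labels.unsqueeze(-1)))
    per_token_loss = paddle.concat(chunks, axis=1)

    # Masked mean in full precision.
    flat_mask = loss_mask.reshape([-1]).cast(paddle.float32)
    total = paddle.sum(per_token_loss.cast(paddle.float32).reshape([-1]) * flat_mask)
    loss = total / flat_mask.sum()

    # Straight-through trick: the reported value equals the masked mean, while
    # gradients follow the fixed-denominator loss.
    _loss = total / token_balance_seqlen
    loss = _loss - _loss.detach() + loss.detach()
    return loss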


class Ernie4_5LMHead(nn.Layer):
    """Language model head for ERNIE with support for tensor parallelism."""

    def __init__(self, config):
        """Initialize the language model head.

        Args:
            config (PaddleOCRVLConfig): Model configuration containing:
                - vocab_size: Size of vocabulary
                - hidden_size: Dimension of hidden states
                - tensor_parallel_degree: Degree of tensor parallelism
                - tie_word_embeddings: Whether to tie input/output embeddings
                - weight_share_add_bias: Whether to add bias when weight sharing
                - use_bias: Whether to use bias term
                - use_recompute_loss_fn: Whether to defer logits computation to loss function
                - use_sparse_head_and_loss_fn: Whether to use sparse head computation
        """
        super(Ernie4_5LMHead, self).__init__()
        self.config = config
        if config.tensor_parallel_degree > 1:
            vocab_size = config.vocab_size // config.tensor_parallel_degree
        else:
            vocab_size = config.vocab_size

        self.weight = self.create_parameter(
            shape=(
                [vocab_size, config.hidden_size]
                if config.tie_word_embeddings
                else [config.hidden_size, vocab_size]
            ),
            dtype=paddle.get_default_dtype(),
        )
        logging.info(
            f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}"
        )
        if config.weight_share_add_bias and config.use_bias:
            self.bias = self.create_parameter(
                shape=[vocab_size],
                dtype=paddle.get_default_dtype(),
                attr=paddle.ParamAttr(
                    initializer=paddle.nn.initializer.constant.Constant(0.0)
                ),
            )
        else:
            self.bias = None

        # Must set distributed attr for Tensor Parallel !
        self.weight.is_distributed = (
            True if (vocab_size != config.vocab_size) else False
        )
        if config.weight_share_add_bias and config.use_bias:
            self.bias.is_distributed = (
                True if (vocab_size != config.vocab_size) else False
            )
        if self.weight.is_distributed:
            self.weight.split_axis = 1
        if (
            config.weight_share_add_bias
            and config.use_bias
            and self.bias.is_distributed
        ):
            self.bias.split_axis = 0

        if self.config.use_recompute_loss_fn:
            logging.info(
                "Using recompute_loss_fn, the calculation of logits will be moved into "
                "loss_fn for memory optimization"
            )

    def forward(self, hidden_states, tensor_parallel_output=None):
        """Project hidden states to vocabulary logits.

        Args:
            hidden_states (paddle.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            tensor_parallel_output (Optional[bool]): Whether to output parallel results. Defaults to None.

        Returns:
            Union[
                Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor], bool]:
                    # When use_recompute_loss_fn or use_sparse_head_and_loss_fn
                    - hidden_states: Original input
                    - weight: Projection weights
                    - bias: Optional bias term
                    - tie_word_embeddings: Whether input/output embeddings are tied
                paddle.Tensor:  # Normal case
                    Logits tensor of shape [batch_size, seq_len, vocab_size]
            ]
        """
        # will enter this branch when:
        # 1. use_recompute_loss_fn or use_sparse_head_and_loss_fn
        # 2. dpo training
        if self.config.use_recompute_loss_fn or self.config.use_sparse_head_and_loss_fn:
            return (
                hidden_states,
                self.weight,
                self.bias,
                self.config.tie_word_embeddings,
            )
        return calc_lm_head_logits(
            self.config,
            hidden_states,
            self.weight,
            self.bias,
            tensor_parallel_output,
            training=self.training,
        )
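

# Illustrative sketch only (not part of the model): why Ernie4_5LMHead stores its
# weight as [vocab_size, hidden_size] when embeddings are tied but as
# [hidden_size, vocab_size] otherwise. With tied weights the embedding table is
# reused and the projection multiplies with transpose_y=True (as the fused loss
# path above does); shapes below are hypothetical.
def _demo_lm_head_projection_sketch():
    import paddle

    hidden = paddle.randn([2, 4, 8])        # [batch, seq, hidden]
    tied_weight = paddle.randn([16, 8])     # embedding-table layout: [vocab, hidden]
    untied_weight = paddle.randn([8, 16])   # plain linear layout: [hidden, vocab]

    logits_tied = paddle.matmul(hidden, tied_weight, transpose_y=True)
    logits_untied = paddle.matmul(hidden, untied_weight)
    assert logits_tied.shape == logits_untied.shape == [2, 4, 16]
    return logits_tied, logits_untied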


class Ernie4_5DecoderLayer(nn.Layer):
    """A single transformer decoder layer in the ERNIE model.

    Contains self-attention and feed-forward components, with support for
    residual connections, dropout, and layer normalization.
    """

    def __init__(self, config, layer_idx):
        """Initialize the decoder layer.

        Args:
            config (PaddleOCRVLConfig): Model configuration.
            layer_idx (int): Index of this layer in the transformer stack
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.config = config
        self.self_attn = Ernie4_5Attention(config, layer_idx)
        self.mlp = Ernie4_5MLP(config)
        Norm = RMSNorm if config.use_rmsnorm else LayerNorm
        self.input_layernorm = Norm(config)
        self.post_attention_layernorm = Norm(config)
        self.residual_add1 = FusedDropoutImpl(
            config.hidden_dropout_prob, mode="upscale_in_train"
        )
        self.residual_add2 = FusedDropoutImpl(
            config.hidden_dropout_prob, mode="upscale_in_train"
        )
        if config.sequence_parallel:
            mark_as_sequence_parallel_parameter(self.post_attention_layernorm.weight)
            if not hasattr(config, "disable_ffn_model_parallel"):
                mark_as_sequence_parallel_parameter(self.input_layernorm.weight)
            if config.use_bias:
                mark_as_sequence_parallel_parameter(self.self_attn.o_proj.bias)
                mark_as_sequence_parallel_parameter(self.mlp.down_proj.bias)
            if not config.use_rmsnorm and config.use_bias:
                mark_as_sequence_parallel_parameter(self.post_attention_layernorm.bias)
                mark_as_sequence_parallel_parameter(self.input_layernorm.bias)

    def forward(
        self,
        hidden_states: paddle.Tensor,
        position_embeddings: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        attn_mask_start_row_indices: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        token_type_ids: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
        """Forward pass through the decoder layer.

        Args:
            hidden_states (paddle.Tensor): Input tensor [batch_size, seq_len, hidden_size]
            position_embeddings (paddle.Tensor): Position embeddings
            attention_mask (Optional[paddle.Tensor]): Attention mask tensor
            attn_mask_start_row_indices (Optional[paddle.Tensor]): Indices for variable length attention
            position_ids (Optional[paddle.Tensor]): Position indices for rotary embeddings
            token_type_ids (Optional[paddle.Tensor]): Token type IDs passed through to self-attention
            output_attentions (Optional[bool]): Whether to return attention weights
            past_key_value (Optional[Tuple[paddle.Tensor]]): Cached key/value states
            use_cache (Optional[bool]): Whether to cache key/value states

        Returns:
            Union: Various output combinations depending on arguments:
                - Base case: Hidden states tensor
                - With attention: Tuple of (hidden_states, attention_weights)
                - With cache: Tuple of (hidden_states, cached_key_value)
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        has_gradient = not hidden_states.stop_gradient
        if (
            self.config.recompute
            and self.config.recompute_granularity == "full_attn"
            and has_gradient
        ):
            hidden_states, self_attn_weights, present_key_value = recompute(
                self.self_attn,
                hidden_states,
                position_embeddings,
                past_key_value,
                attention_mask,
                attn_mask_start_row_indices,
                position_ids,
                output_attentions,
                use_cache,
                use_reentrant=self.config.recompute_use_reentrant,
            )
        else:
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                position_embeddings=position_embeddings,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                attn_mask_start_row_indices=attn_mask_start_row_indices,
                position_ids=position_ids,
                output_attentions=output_attentions,
                use_cache=use_cache,
                token_type_ids=token_type_ids,
            )
        with self.model_parallel_dropout():
            hidden_states = self.residual_add1(hidden_states, residual)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        with self.model_parallel_dropout():
            hidden_states = self.residual_add2(hidden_states, residual)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        # remove empty tuple for pipeline parallel
        if type(outputs) is tuple and len(outputs) == 1:
            outputs = outputs[0]
        return outputs

    def model_parallel_dropout(self):
        """Get context manager for model-parallel dropout with proper seed control.

        Returns:
            Context manager for dropout operation
        """
        if (
            self.config.tensor_parallel_degree > 1
            and self.config.hidden_dropout_prob > 0.0
        ):
            current_seed = (
                "local_seed" if self.config.sequence_parallel else "global_seed"
            )
            return get_rng_state_tracker().rng_state(current_seed)
        return contextlib.nullcontext()
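

# Illustrative sketch only (not part of the model): the ordering used by
# Ernie4_5DecoderLayer above is "pre-norm": each sub-layer sees normalized input,
# and its output is added back onto the un-normalized residual. The stand-in
# sub-layer and sizes below are hypothetical; dropout and seed control are omitted.
def _demo_pre_norm_residual_sketch():
    import paddle

    norm = paddle.nn.LayerNorm(8)
    sublayer = lambda x: x * 0.5  # stand-in for self-attention or the MLP
    hidden = paddle.randn([2, 4, 8])

    residual = hidden
    hidden = norm(hidden)                  # input_layernorm
    hidden = residual + sublayer(hidden)   # residual_add1

    residual = hidden
    hidden = norm(hidden)                  # post_attention_layernorm
    hidden = residual + sublayer(hidden)   # residual_add2
    return hidden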


class Ernie4_5PretrainedModel(PretrainedModel):
    """Base class for ERNIE pretrained models."""

    config_class = PaddleOCRVLConfig
    base_model_prefix = "ernie"

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        """Generate tensor parallel mappings for model conversion.

        Args:
            config (PaddleOCRVLConfig): Model configuration.
            is_split (bool): Whether to generate split mappings (True)
                or merge mappings (False). Defaults to True.

        Returns:
            Dict[str, Callable[[Any], Any]]: Dictionary mapping parameter names
                to their corresponding split/merge functions for tensor parallelism.
        """
        from ..conversion_utils import split_or_merge_func

        fn = split_or_merge_func(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
        )

        def gqa_qkv_split_func(
            weight,
            tensor_parallel_degree,
            tensor_parallel_rank,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            is_quant=False,
            is_split=True,
        ):
            if is_quant:
                weight = weight.T

            def get_shape(tensor):
                return (
                    tensor.get_shape() if hasattr(tensor, "get_shape") else tensor.shape
                )

            def slice_tensor(tensor, start, end):
                shape = get_shape(tensor)
                if len(shape) == 1:
                    return tensor[start:end]
                else:
                    return tensor[..., start:end]

            q_end = num_attention_heads * head_dim
            k_end = q_end + num_key_value_heads * head_dim
            v_end = k_end + num_key_value_heads * head_dim

            q = slice_tensor(weight, 0, q_end)
            k = slice_tensor(weight, q_end, k_end)
            v = slice_tensor(weight, k_end, v_end)

            def split_tensor(tensor, degree):
                shape = get_shape(tensor)
                size = shape[-1]
                block_size = size // degree
                if hasattr(tensor, "get_shape"):
                    return [
                        slice_tensor(tensor, i * block_size, (i + 1) * block_size)
                        for i in range(degree)
                    ]
                else:
                    return np.split(tensor, degree, axis=-1)

            q_list = split_tensor(q, tensor_parallel_degree)
            k_list = split_tensor(k, tensor_parallel_degree)
            v_list = split_tensor(v, tensor_parallel_degree)

            if tensor_parallel_rank is None:
                out = [
                    np.concatenate([q_i, k_i, v_i], axis=-1)
                    for q_i, k_i, v_i in zip(q_list, k_list, v_list)
                ]
            else:
                out = np.concatenate(
                    [
                        q_list[tensor_parallel_rank],
                        k_list[tensor_parallel_rank],
                        v_list[tensor_parallel_rank],
                    ],
                    axis=-1,
                )
            if is_quant:
                out = out.T
            return out

        def gqa_qkv_merge_func(
            weight_list,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            is_quant=False,
            is_split=False,
        ):
            tensor_parallel_degree = len(weight_list)
            num_attention_heads = num_attention_heads // tensor_parallel_degree
            num_key_value_heads = num_key_value_heads // tensor_parallel_degree
            is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)

            def get_shape(tensor):
                return (
                    tensor.get_shape() if hasattr(tensor, "get_shape") else tensor.shape
                )

            def slice_tensor(tensor, start, end):
                if len(get_shape(tensor)) == 1:
                    return tensor[start:end]
                else:
                    return tensor[..., start:end]

            q_list, k_list, v_list = [], [], []
            for weight in weight_list:
                if is_quant:
                    weight = weight.T
                q_end = num_attention_heads * head_dim
                k_end = q_end + num_key_value_heads * head_dim
                v_end = k_end + num_key_value_heads * head_dim
                q = slice_tensor(weight, 0, q_end)
                k = slice_tensor(weight, q_end, k_end)
                v = slice_tensor(weight, k_end, v_end)
                q_list.append(q)
                k_list.append(k)
                v_list.append(v)

            merged = q_list + k_list + v_list
            if is_paddle_tensor:
                tensor = paddle.concat(merged, axis=-1)
                if tensor.place.is_gpu_place():
                    tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
            else:
                tensor = np.concatenate(merged, axis=-1)
            if is_quant:
                tensor = tensor.T
            return tensor

        if (
            config.num_key_value_heads is not None
            and config.num_key_value_heads != config.num_attention_heads
        ):
            if is_split:
                qkv_fn = partial(
                    gqa_qkv_split_func,
                    tensor_parallel_degree=config.tensor_parallel_degree,
                    tensor_parallel_rank=config.tensor_parallel_rank,
                    num_attention_heads=config.num_attention_heads,
                    num_key_value_heads=config.num_key_value_heads,
                    head_dim=(
                        config.hidden_size // config.num_attention_heads
                        if config.head_dim is None
                        else config.head_dim
                    ),
                    is_quant=False,
                    is_split=True,
                )
            else:
                qkv_fn = partial(
                    gqa_qkv_merge_func,
                    num_attention_heads=config.num_attention_heads,
                    num_key_value_heads=config.num_key_value_heads,
                    head_dim=(
                        config.hidden_size // config.num_attention_heads
                        if config.head_dim is None
                        else config.head_dim
                    ),
                    is_quant=False,
                    is_split=False,
                )
        else:
            qkv_fn = partial(fn, is_column=True)

        def get_tensor_parallel_split_mappings(num_hidden_layers):
            final_actions = {}
            base_actions = {
                # Column Linear
                "layers.0.self_attn.qkv_proj.weight": qkv_fn,
                "layers.0.mlp.up_gate_proj.weight": partial(
                    fn, is_column=True, is_naive_2fuse=True
                ),
                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
                # Row Linear
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
            }
            if config.use_bias:
                base_actions.update(
                    {
                        # Column Linear
                        "layers.0.self_attn.qkv_proj.bias": qkv_fn,
                        "layers.0.mlp.up_gate_proj.bias": partial(
                            fn, is_column=True, is_naive_2fuse=True
                        ),
                        # convert PySafeSlice to ndarray.
                        "layers.0.mlp.down_proj.bias": lambda x: x[:],
                        "lm_head.bias": partial(fn, is_column=True),
                    }
                )
            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_hidden_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                else:
                    final_actions[key] = action
            return final_actions

        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
        return mappings
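

# Illustrative sketch only (not part of the model): how gqa_qkv_split_func above
# shards a fused QKV weight for grouped-query attention. Each tensor-parallel rank
# keeps a contiguous slice of query heads plus the matching slice of key/value
# heads, so its shard is again a valid fused QKV block. All sizes are hypothetical.
def _demo_gqa_qkv_split_sketch():
    import numpy as np

    num_attention_heads, num_key_value_heads, head_dim, degree = 4, 2, 3, 2
    hidden = 6
    qkv_cols = (num_attention_heads + 2 * num_key_value_heads) * head_dim
    weight = np.arange(hidden * qkv_cols, dtype=np.float32).reshape(hidden, qkv_cols)

    q_end = num_attention_heads * head_dim
    k_end = q_end + num_key_value_heads * head_dim
    q, k, v = weight[:, :q_end], weight[:, q_end:k_end], weight[:, k_end:]

    q_parts = np.split(q, degree, axis=-1)
    k_parts = np.split(k, degree, axis=-1)
    v_parts = np.split(v, degree, axis=-1)
    shards = [
        np.concatenate([qp, kp, vp], axis=-1)
        for qp, kp, vp in zip(q_parts, k_parts, v_parts)
    ]
    # Each shard has shape [hidden, (4 // 2 + 2 * (2 // 2)) * 3] == [6, 12].
    return shards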


class Ernie4_5Model(Ernie4_5PretrainedModel):
    """The core ERNIE transformer model."""

    def __init__(self, config: PaddleOCRVLConfig):
        """Initialize the ERNIE model architecture.

        Args:
            config (PaddleOCRVLConfig): Model configuration.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.config = config

        if config.tensor_parallel_degree > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                self.hidden_size,
            )
        else:
            self.embed_tokens = nn.Embedding(
                self.vocab_size,
                self.hidden_size,
            )

        self.layers = nn.LayerList(
            [Ernie4_5DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        Norm = RMSNorm if config.use_rmsnorm else LayerNorm
        self.norm = Norm(config)
        self.rotary_emb = KeyeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Get the input embedding layer.

        Returns:
            nn.Embedding: The embedding layer for input tokens
        """
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Set new input embeddings.

        Args:
            value (nn.Embedding): New embedding layer to use
        """
        self.embed_tokens = value

    @paddle.jit.not_to_static
    def recompute_training(
        self,
        layer_module,
        hidden_states,
        position_embeddings,
        attention_mask,
        attn_mask_start_row_indices,
        position_ids,
        token_type_ids,
        output_attentions,
        past_key_value,
        use_cache,
    ):
        """Perform gradient checkpointing for memory-efficient training.

        Args:
            layer_module (nn.Layer): Transformer layer to recompute
            hidden_states (paddle.Tensor): Input hidden states
            position_embeddings (paddle.Tensor): Position embeddings
            attention_mask (paddle.Tensor): Attention mask
            attn_mask_start_row_indices (paddle.Tensor): Variable length indices
            position_ids (paddle.Tensor): Position indices
            token_type_ids (paddle.Tensor): Token type IDs
            output_attentions (bool): Whether to output attention weights
            past_key_value (Optional[Tuple[paddle.Tensor]]): Cached key/value states
            use_cache (bool): Whether to cache key/value states

        Returns:
            paddle.Tensor: Output hidden states after recomputation
        """

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs, output_gate_logits=False)

            return custom_forward

        hidden_states = recompute(
            create_custom_forward(layer_module),
            hidden_states,
            position_embeddings,
            attention_mask,
            attn_mask_start_row_indices,
            position_ids,
            token_type_ids,
            output_attentions,
            past_key_value,
            use_cache,
        )
        return hidden_states

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        attn_mask_start_row_indices=None,
        inputs_embeds=None,
        use_cache=None,
        past_key_values=None,
        output_attentions=False,
        output_hidden_states=None,
        return_dict=False,
    ):
        """Forward pass through the ERNIE model.

        Args:
            input_ids (Optional[paddle.Tensor]): Input token IDs
            position_ids (Optional[paddle.Tensor]): Position indices
            token_type_ids (Optional[paddle.Tensor]): Token type IDs
            attention_mask (Optional[paddle.Tensor]): Attention mask
            attn_mask_start_row_indices (Optional[paddle.Tensor]): Variable length attention indices
            inputs_embeds (Optional[paddle.Tensor]): Precomputed embeddings
            use_cache (Optional[bool]): Whether to cache key/value states
            past_key_values (Optional[Tuple[Tuple[paddle.Tensor]]]): Cached key/value states
            output_attentions (Optional[bool]): Whether to output attention weights
            output_hidden_states (Optional[bool]): Whether to output all hidden states
            return_dict (Optional[bool]): Whether to return dict or tuple

        Returns:
            Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
                Various outputs depending on configuration, including:
                - last_hidden_state: Final layer hidden states
                - past_key_values: Cached key/value states if use_cache=True
                - hidden_states: All hidden states if output_hidden_states=True
                - attentions: Attention weights if output_attentions=True
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds"
            )
        if batch_size != 1:
            raise NotImplementedError

        layers = self.layers[: self.config.num_hidden_layers]

        if past_key_values is None:
            past_key_values = tuple([None] * len(layers))
            kv_seq_len = 0
        else:
            kv_seq_len = past_key_values[0][0].shape[1]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        inputs_embeds = inputs_embeds.astype(self.embed_tokens.weight.dtype)

        if self.config.sequence_parallel:
            inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
            inputs_embeds = ScatterOp.apply(inputs_embeds)

        hidden_states = inputs_embeds

        if position_ids is None or position_ids.dim() == 2:
            raise NotImplementedError
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        if attention_mask is None:
            raise NotImplementedError
        causal_mask = self._update_causal_mask(
            attention_mask.astype("int64"),
            inputs_embeds,
            past_key_values,
            output_attentions,
        )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            has_gradient = not hidden_states.stop_gradient
            if (
                self.config.recompute
                and self.config.recompute_granularity == "full"
                and has_gradient
            ):
                layer_outputs = self.recompute_training(
                    decoder_layer,
                    hidden_states,
                    position_embeddings,
                    causal_mask,
                    attn_mask_start_row_indices,
                    position_ids,
                    token_type_ids,
                    output_attentions,
                    past_key_value,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_embeddings,
                    causal_mask,
                    attn_mask_start_row_indices,
                    position_ids,
                    token_type_ids,
                    output_attentions,
                    past_key_value,
                    use_cache,
                )

            if isinstance(layer_outputs, (tuple, list)):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=None,
        )

    def _update_causal_mask(
        self,
        attention_mask: paddle.Tensor,
        input_tensor: paddle.Tensor,
        past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]],
        output_attentions: bool = False,
    ):
        past_seen_tokens = (
            past_key_values[0][0].shape[1]
            if past_key_values is not None and past_key_values[0] is not None
            else 0
        )
        dtype = input_tensor.dtype
        min_dtype = paddle.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, paddle.Tensor)
            else past_seen_tokens + sequence_length + 1
        )
        cache_position = paddle.arange(
            past_seen_tokens, past_seen_tokens + sequence_length
        )
        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length: int,
        target_length: int,
        dtype,
        cache_position,
        batch_size: int,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form
            # and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = paddle.finfo(dtype).min
            causal_mask = paddle.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype
            )
            diagonal_attend_mask = paddle.arange(
                target_length
            ) > cache_position.reshape((-1, 1))
            diagonal_attend_mask = diagonal_attend_mask.astype(causal_mask.dtype)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand((batch_size, 1, -1, -1))
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                    :, None, None, :
                ].astype(causal_mask.dtype)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask
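
    # Note (illustrative, not executed): for a single sequence of length 4 whose
    # last key position is padded (attention_mask == [1, 1, 1, 0]), the helper
    # above produces an additive mask where allowed positions are 0 and blocked
    # positions are finfo(dtype).min, e.g.
    #
    #     causal_mask[0, 0] ==
    #         [[0,   min, min, min],
    #          [0,   0,   min, min],
    #          [0,   0,   0,   min],
    #          [0,   0,   0,   min]]
    #
    # i.e. the usual lower-triangular causal pattern, with the padded key column
    # additionally filled with min_dtype by the masked_fill step.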