qwen2_vl.py 99 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. from dataclasses import dataclass
  17. from typing import Any, Dict, List, Optional, Tuple, Union
  18. import paddle
  19. import paddle.distributed.fleet.meta_parallel as mpu
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle import Tensor
  23. from paddle.distributed import fleet
  24. from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
  25. from paddle.distributed.fleet.utils import recompute
  26. from .....utils import logging
  27. from ....utils.benchmark import (
  28. benchmark,
  29. get_inference_operations,
  30. set_inference_operations,
  31. )
  32. from ...common.vlm.activations import ACT2FN
  33. from ...common.vlm.bert_padding import index_first_axis, pad_input, unpad_input
  34. from ...common.vlm.flash_attn_utils import has_flash_attn_func
  35. from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
  36. from ...common.vlm.transformers.model_outputs import (
  37. BaseModelOutputWithPast,
  38. ModelOutput,
  39. )
  40. flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
  41. _IS_NPU = "npu" in paddle.get_device()
  42. Linear = nn.Linear
  43. ColumnParallelLinear = mpu.ColumnParallelLinear
  44. RowParallelLinear = mpu.RowParallelLinear
  45. class Qwen2VLVisionConfig(PretrainedConfig):
  46. model_type = "qwen2_vl"
  47. def __init__(
  48. self,
  49. depth=32,
  50. embed_dim=1280,
  51. hidden_size=3584,
  52. hidden_act="quick_gelu",
  53. mlp_ratio=4,
  54. num_heads=16,
  55. in_channels=3,
  56. patch_size=14,
  57. spatial_merge_size=2,
  58. temporal_patch_size=2,
  59. attn_implementation="eager", # new added
  60. **kwargs,
  61. ):
  62. super().__init__(**kwargs)
  63. self.depth = depth
  64. self.embed_dim = embed_dim
  65. self.hidden_size = hidden_size
  66. self.hidden_act = hidden_act
  67. self.mlp_ratio = mlp_ratio
  68. self.num_heads = num_heads
  69. self.in_channels = in_channels
  70. self.patch_size = patch_size
  71. self.spatial_merge_size = spatial_merge_size
  72. self.temporal_patch_size = temporal_patch_size
  73. self.attn_implementation = attn_implementation
  74. @classmethod
  75. def from_pretrained(
  76. cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
  77. ) -> "PretrainedConfig":
  78. config_dict, kwargs = cls.get_config_dict(
  79. pretrained_model_name_or_path, **kwargs
  80. )
  81. if config_dict.get("model_type") == "qwen2_vl":
  82. config_dict = config_dict["vision_config"]
  83. if (
  84. "model_type" in config_dict
  85. and hasattr(cls, "model_type")
  86. and config_dict["model_type"] != cls.model_type
  87. ):
  88. logging.warning(
  89. f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
  90. f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
  91. )
  92. return cls.from_dict(config_dict, **kwargs)
  93. class Qwen2VLConfig(PretrainedConfig):
  94. r"""
  95. This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
  96. Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
  97. with the defaults will yield a similar configuration to that of
  98. Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
  99. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  100. documentation from [`PretrainedConfig`] for more information.
  101. Args:
  102. vocab_size (`int`, *optional*, defaults to 152064):
  103. Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
  104. `inputs_ids` passed when calling [`Qwen2VLModel`]
  105. hidden_size (`int`, *optional*, defaults to 8192):
  106. Dimension of the hidden representations.
  107. intermediate_size (`int`, *optional*, defaults to 29568):
  108. Dimension of the MLP representations.
  109. num_hidden_layers (`int`, *optional*, defaults to 80):
  110. Number of hidden layers in the Transformer encoder.
  111. num_attention_heads (`int`, *optional*, defaults to 64):
  112. Number of attention heads for each attention layer in the Transformer encoder.
  113. num_key_value_heads (`int`, *optional*, defaults to 8):
  114. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  115. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  116. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  117. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  118. by meanpooling all the original heads within that group. For more details checkout [this
  119. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
  120. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  121. The non-linear activation function (function or string) in the decoder.
  122. max_position_embeddings (`int`, *optional*, defaults to 32768):
  123. The maximum sequence length that this model might ever be used with.
  124. initializer_range (`float`, *optional*, defaults to 0.02):
  125. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  126. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  127. The epsilon used by the rms normalization layers.
  128. use_cache (`bool`, *optional*, defaults to `True`):
  129. Whether or not the model should return the last key/values attentions (not used by all models). Only
  130. relevant if `config.is_decoder=True`.
  131. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  132. Whether the model's input and output word embeddings should be tied.
  133. rope_theta (`float`, *optional*, defaults to 1000000.0):
  134. The base period of the RoPE embeddings.
  135. use_sliding_window (`bool`, *optional*, defaults to `False`):
  136. Whether to use sliding window attention.
  137. sliding_window (`int`, *optional*, defaults to 4096):
  138. Sliding window attention (SWA) window size. If not specified, will default to `4096`.
  139. max_window_layers (`int`, *optional*, defaults to 80):
  140. The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  141. attention_dropout (`float`, *optional*, defaults to 0.0):
  142. The dropout ratio for the attention probabilities.
  143. vision_config (`Dict`, *optional*):
  144. The config for the visual encoder initialization.
  145. rope_scaling (`Dict`, *optional*):
  146. Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
  147. strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
  148. `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
  149. `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
  150. these scaling strategies behave:
  151. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
  152. experimental feature, subject to breaking API changes in future versions.
  153. """
  154. model_type = "qwen2_vl"
  155. keys_to_ignore_at_inference = ["past_key_values"]
  156. def __init__(
  157. self,
  158. vocab_size=152064,
  159. hidden_size=8192,
  160. intermediate_size=29568,
  161. num_hidden_layers=80,
  162. num_attention_heads=64,
  163. num_key_value_heads=8,
  164. hidden_act="silu",
  165. max_position_embeddings=32768,
  166. initializer_range=0.02,
  167. rms_norm_eps=1e-05,
  168. use_cache=True,
  169. tie_word_embeddings=False,
  170. rope_theta=1000000.0,
  171. use_sliding_window=False,
  172. sliding_window=4096,
  173. max_window_layers=80,
  174. attention_dropout=0.0,
  175. vision_config=None,
  176. rope_scaling=None,
  177. **kwargs,
  178. ):
  179. if isinstance(vision_config, dict):
  180. self.vision_config = Qwen2VLVisionConfig(**vision_config)
  181. elif vision_config is None:
  182. self.vision_config = Qwen2VLVisionConfig()
  183. self.vocab_size = vocab_size
  184. self.max_position_embeddings = max_position_embeddings
  185. self.hidden_size = hidden_size
  186. self.intermediate_size = intermediate_size
  187. self.num_hidden_layers = num_hidden_layers
  188. self.num_attention_heads = num_attention_heads
  189. self.use_sliding_window = use_sliding_window
  190. self.sliding_window = sliding_window
  191. self.max_window_layers = max_window_layers
  192. if num_key_value_heads is None:
  193. num_key_value_heads = num_attention_heads
  194. self.num_key_value_heads = num_key_value_heads
  195. self.hidden_act = hidden_act
  196. self.initializer_range = initializer_range
  197. self.rms_norm_eps = rms_norm_eps
  198. self.use_cache = use_cache
  199. self.rope_theta = rope_theta
  200. self.attention_dropout = attention_dropout
  201. self.rope_scaling = rope_scaling
  202. super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
  203. def get_triangle_upper_mask(x, mask=None):
  204. if mask is not None:
  205. return mask
  206. shape = x.shape
  207. shape[1] = 1
  208. mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
  209. mask = paddle.triu(mask, diagonal=1)
  210. mask.stop_gradient = True
  211. return mask
  212. def parallel_matmul(
  213. x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
  214. ):
  215. is_fleet_init = True
  216. tensor_parallel_degree = 1
  217. try:
  218. hcg = fleet.get_hybrid_communicate_group()
  219. model_parallel_group = hcg.get_model_parallel_group()
  220. tensor_parallel_degree = hcg.get_model_parallel_world_size()
  221. except:
  222. is_fleet_init = False
  223. if paddle.in_dynamic_mode():
  224. y_is_distributed = y.is_distributed
  225. else:
  226. y_is_distributed = tensor_parallel_degree > 1
  227. if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
  228. input_parallel = paddle.distributed.collective._c_identity(
  229. x, group=model_parallel_group
  230. )
  231. logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
  232. if tensor_parallel_output:
  233. return logits
  234. return paddle.distributed.collective._c_concat(
  235. logits, group=model_parallel_group
  236. )
  237. else:
  238. logits = paddle.matmul(x, y, transpose_y=transpose_y)
  239. return logits
  240. def _compute_default_rope_parameters(
  241. config: Optional[PretrainedConfig] = None,
  242. device: Optional["paddle.device"] = None,
  243. seq_len: Optional[int] = None,
  244. **rope_kwargs,
  245. ) -> Tuple["paddle.Tensor", float]:
  246. """
  247. Computes the inverse frequencies according to the original RoPE implementation
  248. Args:
  249. config ([`~transformers.PretrainedConfig`]):
  250. The model configuration.
  251. device (`paddle.device`):
  252. The device to use for initialization of the inverse frequencies.
  253. seq_len (`int`, *optional*):
  254. The current sequence length. Unused for this type of RoPE.
  255. rope_kwargs (`Dict`, *optional*):
  256. BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
  257. Returns:
  258. Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  259. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  260. """
  261. if config is not None and len(rope_kwargs) > 0:
  262. raise ValueError(
  263. "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
  264. f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
  265. )
  266. if len(rope_kwargs) > 0:
  267. base = rope_kwargs["base"]
  268. dim = rope_kwargs["dim"]
  269. elif config is not None:
  270. base = config.rope_theta
  271. partial_rotary_factor = (
  272. config.partial_rotary_factor
  273. if hasattr(config, "partial_rotary_factor")
  274. else 1.0
  275. )
  276. head_dim = getattr(
  277. config, "head_dim", config.hidden_size // config.num_attention_heads
  278. )
  279. dim = int(head_dim * partial_rotary_factor)
  280. attention_factor = 1.0
  281. inv_freq = 1.0 / (
  282. base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)
  283. )
  284. return inv_freq, attention_factor
  285. ROPE_INIT_FUNCTIONS = {
  286. "default": _compute_default_rope_parameters,
  287. }
  288. def _get_unpad_data(attention_mask):
  289. seqlens_in_batch = attention_mask.sum(axis=-1, dtype="int32")
  290. indices = paddle.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  291. max_seqlen_in_batch = seqlens_in_batch.max().item() # [2, 1, 1323]
  292. cu_seqlens = F.pad(
  293. paddle.cumsum(seqlens_in_batch, axis=0), (1, 0), data_format="NCL"
  294. )
  295. return (
  296. indices,
  297. cu_seqlens,
  298. max_seqlen_in_batch,
  299. )
  300. def is_casual_mask(attention_mask):
  301. """
  302. Upper triangular of attention_mask equals to attention_mask is casual
  303. """
  304. return (paddle.triu(attention_mask) == attention_mask).all().item()
  305. def _make_causal_mask(input_ids_shape, past_key_values_length):
  306. """
  307. Make causal mask used for self-attention
  308. """
  309. batch_size, target_length = input_ids_shape
  310. mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
  311. if past_key_values_length > 0:
  312. mask = paddle.concat(
  313. [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
  314. axis=-1,
  315. )
  316. return mask[None, None, :, :].expand(
  317. [batch_size, 1, target_length, target_length + past_key_values_length]
  318. )
  319. def _expand_2d_mask(mask, dtype, tgt_length):
  320. """
  321. Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
  322. """
  323. batch_size, src_length = mask.shape[0], mask.shape[-1]
  324. tgt_length = tgt_length if tgt_length is not None else src_length
  325. mask = mask[:, None, None, :].astype("bool")
  326. mask.stop_gradient = True
  327. expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
  328. return expanded_mask
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Qwen2VL causal language model (or autoregressive) outputs.

    All fields default to None; which ones are populated depends on the producing model.

    Args:
        loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (*optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Cached key/value states that can be fed back (see `past_key_values` input) to speed up
            sequential decoding. Annotated here as `List[paddle.Tensor]`; the exact per-layer nesting
            is decided by the producing model. NOTE(review): confirm whether it is a list of
            (key, value) pairs of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
        hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        rope_deltas (`paddle.Tensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    # Field order is part of the dataclass contract (positional construction / tuple conversion).
    loss: Optional[paddle.Tensor] = None
    logits: paddle.Tensor = None
    past_key_values: Optional[List[paddle.Tensor]] = None
    hidden_states: Optional[Tuple[paddle.Tensor]] = None
    attentions: Optional[Tuple[paddle.Tensor]] = None
    rope_deltas: Optional[paddle.Tensor] = None
  361. class Qwen2VLRotaryEmbedding(nn.Layer):
  362. def __init__(
  363. self,
  364. dim=None,
  365. max_position_embeddings=2048,
  366. base=10000,
  367. device=None,
  368. scaling_factor=1.0,
  369. rope_type="default",
  370. config: Optional[Qwen2VLConfig] = None,
  371. ):
  372. super().__init__()
  373. self.rope_kwargs = {}
  374. if config is None:
  375. self.rope_kwargs = {
  376. "rope_type": rope_type,
  377. "factor": scaling_factor,
  378. "dim": dim,
  379. "base": base,
  380. "max_position_embeddings": max_position_embeddings,
  381. }
  382. self.rope_type = rope_type
  383. self.max_seq_len_cached = max_position_embeddings
  384. self.original_max_seq_len = max_position_embeddings
  385. else:
  386. # BC: "rope_type" was originally "type"
  387. if config.rope_scaling is not None:
  388. self.rope_type = config.rope_scaling.get(
  389. "rope_type", config.rope_scaling.get("type")
  390. )
  391. else:
  392. self.rope_type = "default"
  393. self.max_seq_len_cached = config.max_position_embeddings
  394. self.original_max_seq_len = config.max_position_embeddings
  395. self.config = config
  396. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  397. self.inv_freq, self.attention_scaling = self.rope_init_fn(
  398. self.config, device, **self.rope_kwargs
  399. )
  400. self.original_inv_freq = self.inv_freq
  401. self._set_cos_sin_cache(seq_len=max_position_embeddings)
  402. def _set_cos_sin_cache(self, seq_len):
  403. self.max_seq_len_cached = seq_len
  404. t = paddle.arange(seq_len, dtype="float32")
  405. freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
  406. emb = paddle.concat([freqs, freqs], axis=-1)
  407. self.cos_cached = emb.cos()
  408. self.sin_cached = emb.sin()
  409. def _dynamic_frequency_update(self, position_ids, device):
  410. """
  411. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  412. 1 - growing beyond the cached sequence length (allow scaling)
  413. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  414. """
  415. seq_len = paddle.max(position_ids) + 1
  416. if seq_len > self.max_seq_len_cached: # growth
  417. inv_freq, self.attention_scaling = self.rope_init_fn(
  418. self.config, device, seq_len=seq_len, **self.rope_kwargs
  419. )
  420. self.inv_freq = inv_freq
  421. self.max_seq_len_cached = seq_len
  422. if (
  423. seq_len < self.original_max_seq_len
  424. and self.max_seq_len_cached > self.original_max_seq_len
  425. ): # reset
  426. self.inv_freq = self.original_inv_freq
  427. self.max_seq_len_cached = self.original_max_seq_len
  428. @paddle.no_grad()
  429. def forward(self, x, position_ids):
  430. if "dynamic" in self.rope_type:
  431. self._dynamic_frequency_update(position_ids, device=x.device)
  432. inv_freq_expanded = (
  433. self.inv_freq[None, None, :, None]
  434. .astype("float32")
  435. .expand([3, position_ids.shape[1], -1, 1])
  436. )
  437. position_ids_expanded = position_ids[:, :, None, :].astype("float32")
  438. device_type = paddle.get_device()
  439. device_type = (
  440. device_type
  441. if isinstance(device_type, str) and device_type != "mps"
  442. else "cpu"
  443. )
  444. with paddle.amp.auto_cast():
  445. freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded)
  446. freqs = freqs.transpose([0, 1, 3, 2])
  447. emb = paddle.concat((freqs, freqs), axis=-1)
  448. cos = emb.cos()
  449. sin = emb.sin()
  450. cos = cos * self.attention_scaling
  451. sin = sin * self.attention_scaling
  452. return cos.astype(x.dtype), sin.astype(x.dtype)
  453. # Copied from transformers.models.llama.modeling_llama.rotate_half
  454. def rotate_half(x):
  455. """Rotates half the hidden dims of the input."""
  456. x1 = x[..., : x.shape[-1] // 2]
  457. x2 = x[..., x.shape[-1] // 2 :]
  458. return paddle.concat([-x2, x1], axis=-1)
  459. def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  460. """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
  461. Explanation:
  462. Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
  463. sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
  464. vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
  465. Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
  466. For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
  467. height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
  468. difference with modern LLMs.
  469. Args:
  470. q (`paddle.Tensor`): The query tensor.
  471. k (`paddle.Tensor`): The key tensor.
  472. cos (`paddle.Tensor`): The cosine part of the rotary embedding.
  473. sin (`paddle.Tensor`): The sine part of the rotary embedding.
  474. position_ids (`paddle.Tensor`):
  475. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  476. used to pass offsetted position ids when working with a KV-cache.
  477. mrope_section(`List(int)`):
  478. Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
  479. unsqueeze_dim (`int`, *optional*, defaults to 1):
  480. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  481. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  482. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  483. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  484. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  485. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  486. Returns:
  487. `tuple(paddle.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  488. """
  489. mrope_section = mrope_section * 2
  490. cos = paddle.concat(
  491. x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1
  492. ).unsqueeze(axis=unsqueeze_dim)
  493. sin = paddle.concat(
  494. x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1
  495. ).unsqueeze(axis=unsqueeze_dim)
  496. q_embed = (q * cos) + (rotate_half(q) * sin)
  497. k_embed = (k * cos) + (rotate_half(k) * sin)
  498. return q_embed, k_embed
  499. def apply_rotary_pos_emb_vision(
  500. tensor: paddle.Tensor, freqs: paddle.Tensor
  501. ) -> paddle.Tensor:
  502. orig_dtype = tensor.dtype
  503. with paddle.amp.auto_cast(False):
  504. tensor = tensor.astype(dtype="float32")
  505. cos = freqs.cos()
  506. sin = freqs.sin()
  507. cos = (
  508. cos.unsqueeze(1)
  509. .tile(repeat_times=[1, 1, 2])
  510. .unsqueeze(0)
  511. .astype(dtype="float32")
  512. )
  513. sin = (
  514. sin.unsqueeze(1)
  515. .tile(repeat_times=[1, 1, 2])
  516. .unsqueeze(0)
  517. .astype(dtype="float32")
  518. )
  519. output = tensor * cos + rotate_half(tensor) * sin
  520. output = paddle.cast(output, orig_dtype)
  521. return output
  522. class VisionRotaryEmbedding(nn.Layer):
  523. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  524. super().__init__()
  525. self.inv_freq = 1.0 / theta ** (
  526. paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim
  527. )
  528. def forward(self, seqlen: int) -> paddle.Tensor:
  529. seq = paddle.arange(seqlen).cast(self.inv_freq.dtype)
  530. freqs = paddle.outer(x=seq, y=self.inv_freq)
  531. return freqs
  532. class PatchEmbed(nn.Layer):
  533. def __init__(
  534. self,
  535. patch_size: int = 14,
  536. temporal_patch_size: int = 2,
  537. in_channels: int = 3,
  538. embed_dim: int = 1152,
  539. ) -> None:
  540. super().__init__()
  541. self.patch_size = patch_size
  542. self.temporal_patch_size = temporal_patch_size
  543. self.in_channels = in_channels
  544. self.embed_dim = embed_dim
  545. kernel_size = [temporal_patch_size, patch_size, patch_size]
  546. self.proj = nn.Conv3D(
  547. in_channels,
  548. embed_dim,
  549. kernel_size=kernel_size,
  550. stride=kernel_size,
  551. bias_attr=False,
  552. )
  553. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  554. target_dtype = self.proj.weight.dtype
  555. hidden_states = hidden_states.reshape(
  556. [
  557. -1,
  558. self.in_channels,
  559. self.temporal_patch_size,
  560. self.patch_size,
  561. self.patch_size,
  562. ]
  563. )
  564. # NOTE(changwenbin): AttributeError: 'Variable' object has no attribute 'to'.
  565. # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
  566. # hidden_states = paddle.cast(hidden_states, dtype=target_dtype)
  567. hidden_states = self.proj(
  568. paddle.cast(hidden_states, dtype=target_dtype)
  569. ).reshape([-1, self.embed_dim])
  570. return hidden_states
  571. class PatchMerger(nn.Layer):
  572. def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
  573. super().__init__()
  574. self.hidden_size = context_dim * (spatial_merge_size**2)
  575. self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6)
  576. self.mlp = nn.Sequential(
  577. nn.Linear(self.hidden_size, self.hidden_size),
  578. nn.GELU(),
  579. nn.Linear(self.hidden_size, dim),
  580. )
  581. def forward(self, x: paddle.Tensor) -> paddle.Tensor:
  582. x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
  583. return x
  584. class VisionMlp(nn.Layer):
  585. def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
  586. super().__init__()
  587. self.fc1 = nn.Linear(dim, hidden_dim)
  588. self.act = ACT2FN[hidden_act]
  589. self.fc2 = nn.Linear(hidden_dim, dim)
  590. def forward(self, x) -> paddle.Tensor:
  591. return self.fc2(self.act(self.fc1(x)))
  592. class VisionAttention(nn.Layer):
  593. def __init__(self, dim: int, num_heads: int = 16) -> None:
  594. super().__init__()
  595. self.num_heads = num_heads
  596. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  597. self.proj = nn.Linear(dim, dim)
  598. self.head_dim = dim // num_heads # must added
  599. def forward(
  600. self,
  601. hidden_states: paddle.Tensor,
  602. cu_seqlens: paddle.Tensor,
  603. rotary_pos_emb: paddle.Tensor = None,
  604. ) -> paddle.Tensor:
  605. seq_length = hidden_states.shape[0]
  606. q, k, v = (
  607. self.qkv(hidden_states)
  608. .reshape([seq_length, 3, self.num_heads, -1])
  609. .transpose([1, 0, 2, 3])
  610. .unbind(0)
  611. )
  612. q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
  613. k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
  614. attention_mask = paddle.zeros([1, seq_length, seq_length], dtype="bool")
  615. for i in range(1, len(cu_seqlens)):
  616. attention_mask[
  617. ...,
  618. cu_seqlens[i - 1] : cu_seqlens[i],
  619. cu_seqlens[i - 1] : cu_seqlens[i],
  620. ] = True
  621. zero = paddle.zeros(attention_mask.shape, dtype=hidden_states.dtype)
  622. neg_inf = paddle.full_like(
  623. attention_mask,
  624. paddle.finfo(hidden_states.dtype).min,
  625. dtype=hidden_states.dtype,
  626. )
  627. attention_mask = paddle.where(attention_mask, zero, neg_inf)
  628. q = q.transpose([1, 0, 2])
  629. k = k.transpose([1, 0, 2])
  630. v = v.transpose([1, 0, 2])
  631. attn_weights = paddle.matmul(q, k.transpose([0, 2, 1])) / math.sqrt(
  632. self.head_dim
  633. )
  634. attn_weights = attn_weights + attention_mask
  635. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  636. attn_output = paddle.matmul(attn_weights, v)
  637. attn_output = attn_output.transpose([1, 0, 2])
  638. attn_output = attn_output.reshape([seq_length, -1])
  639. attn_output = self.proj(attn_output)
  640. return attn_output
  641. class VisionFlashAttention2(nn.Layer):
  642. def __init__(self, dim: int, num_heads: int = 16) -> None:
  643. super().__init__()
  644. self.num_heads = num_heads
  645. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  646. self.proj = nn.Linear(dim, dim)
  647. self.head_dim = dim // num_heads # must added
  648. def forward(
  649. self,
  650. hidden_states: paddle.Tensor,
  651. cu_seqlens: paddle.Tensor,
  652. rotary_pos_emb: paddle.Tensor = None,
  653. ) -> paddle.Tensor:
  654. seq_length = tuple(hidden_states.shape)[0]
  655. qkv = (
  656. self.qkv(hidden_states)
  657. .reshape([seq_length, 3, self.num_heads, -1])
  658. .transpose(perm=[1, 0, 2, 3])
  659. )
  660. q, k, v = qkv.unbind(axis=0)
  661. q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  662. axis=0
  663. )
  664. k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  665. axis=0
  666. )
  667. if _IS_NPU:
  668. attn_output = paddle.nn.functional.flash_attention_npu(
  669. q.astype("bfloat16"),
  670. k.astype("bfloat16"),
  671. v.astype("bfloat16"),
  672. is_varlen=True,
  673. batch_size=1,
  674. seq_length=seq_length,
  675. ).reshape([seq_length, -1])
  676. else:
  677. max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
  678. softmax_scale = self.head_dim**-0.5
  679. attn_output = (
  680. flash_attn_varlen_func(
  681. q.astype("bfloat16"),
  682. k.astype("bfloat16"),
  683. v.astype("bfloat16"),
  684. cu_seqlens,
  685. cu_seqlens,
  686. max_seqlen,
  687. max_seqlen,
  688. scale=softmax_scale,
  689. )[0]
  690. .squeeze(0)
  691. .reshape([seq_length, -1])
  692. )
  693. if self.proj.weight.dtype == paddle.bfloat16:
  694. attn_output = attn_output.astype(paddle.bfloat16)
  695. elif self.proj.weight.dtype == paddle.float16:
  696. attn_output = attn_output.astype(paddle.float16)
  697. elif self.proj.weight.dtype == paddle.float32:
  698. attn_output = attn_output.astype(paddle.float32)
  699. attn_output = self.proj(attn_output)
  700. return attn_output
  701. def create_attention_module(config, module_type, layer_idx=None):
  702. if flash_attn_func is not None:
  703. if module_type == "qwen2vl":
  704. return Qwen2VLFlashAttention2(config, layer_idx)
  705. elif module_type == "vision":
  706. return VisionFlashAttention2(config.embed_dim, num_heads=config.num_heads)
  707. else:
  708. logging.warning_once(
  709. f"Warning: Flash Attention2 is not available for {module_type}, fallback to normal attention."
  710. )
  711. if module_type == "qwen2vl":
  712. return Qwen2VLAttention(config, layer_idx)
  713. elif module_type == "vision":
  714. return VisionAttention(config.embed_dim, num_heads=config.num_heads)
  715. class Qwen2VLVisionBlock(nn.Layer):
  716. def __init__(self, config, attn_implementation: str = "flash_attention_2") -> None:
  717. super().__init__()
  718. self.norm1 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  719. self.norm2 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  720. mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
  721. self.attn = create_attention_module(config, "vision")
  722. self.mlp = VisionMlp(
  723. dim=config.embed_dim,
  724. hidden_dim=mlp_hidden_dim,
  725. hidden_act=config.hidden_act,
  726. )
  727. def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
  728. hidden_states = hidden_states + self.attn(
  729. self.norm1(hidden_states),
  730. cu_seqlens=cu_seqlens,
  731. rotary_pos_emb=rotary_pos_emb,
  732. )
  733. hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
  734. return hidden_states
  735. def _prepare_4d_causal_attention_mask_with_cache_position(
  736. attention_mask: paddle.Tensor,
  737. sequence_length: int,
  738. target_length: int,
  739. dtype: paddle.dtype,
  740. min_dtype: float,
  741. cache_position: paddle.Tensor,
  742. batch_size: int,
  743. ):
  744. """
  745. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  746. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  747. Args:
  748. attention_mask (`paddle.Tensor`):
  749. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  750. sequence_length (`int`):
  751. The sequence length being processed.
  752. target_length (`int`):
  753. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  754. dtype (`paddle.dtype`):
  755. The dtype to use for the 4D attention mask.
  756. min_dtype (`float`):
  757. The minimum value representable with the dtype `dtype`.
  758. cache_position (`paddle.Tensor`):
  759. Indices depicting the position of the input sequence tokens in the sequence.
  760. batch_size (`paddle.Tensor`):
  761. Batch size.
  762. """
  763. if attention_mask is not None and attention_mask.dim() == 4:
  764. causal_mask = attention_mask
  765. else:
  766. causal_mask = paddle.full(
  767. [sequence_length, target_length], fill_value=min_dtype, dtype=dtype
  768. )
  769. if sequence_length != 1:
  770. causal_mask = paddle.triu(x=causal_mask, diagonal=1)
  771. causal_mask *= paddle.arange(target_length) > cache_position.reshape([-1, 1])
  772. causal_mask = causal_mask[None, None, :, :].expand(
  773. shape=[batch_size, 1, -1, -1]
  774. )
  775. if attention_mask is not None:
  776. causal_mask = causal_mask.clone()
  777. mask_length = tuple(attention_mask.shape)[-1]
  778. padding_mask = (
  779. causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  780. )
  781. padding_mask = padding_mask == 0
  782. causal_mask[:, :, :, :mask_length] = causal_mask[
  783. :, :, :, :mask_length
  784. ].masked_fill(mask=padding_mask, value=min_dtype)
  785. return causal_mask
  786. class Qwen2RMSNorm(nn.Layer):
  787. def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
  788. """
  789. Qwen2RMSNorm is equivalent to T5LayerNorm
  790. """
  791. super().__init__()
  792. self.weight = paddle.create_parameter(
  793. shape=[hidden_size],
  794. dtype=paddle.get_default_dtype(),
  795. default_initializer=nn.initializer.Constant(1.0),
  796. )
  797. self.variance_epsilon = eps
  798. def forward(self, hidden_states):
  799. if paddle.in_dynamic_mode():
  800. with paddle.amp.auto_cast(False):
  801. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  802. hidden_states = (
  803. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  804. )
  805. else:
  806. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  807. hidden_states = (
  808. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  809. )
  810. if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
  811. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  812. return hidden_states * self.weight
  813. class Qwen2MLP(nn.Layer):
  814. def __init__(self, config):
  815. super().__init__()
  816. self.hidden_size = config.hidden_size
  817. self.intermediate_size = config.intermediate_size
  818. self.fuse_attention_ffn = config.fuse_attention_ffn
  819. self.tensor_parallel_degree = config.tensor_parallel_degree
  820. if config.tensor_parallel_degree > 1:
  821. self.gate_proj = ColumnParallelLinear(
  822. self.hidden_size,
  823. self.intermediate_size,
  824. gather_output=False,
  825. has_bias=False,
  826. )
  827. self.up_proj = ColumnParallelLinear(
  828. self.hidden_size,
  829. self.intermediate_size,
  830. gather_output=False,
  831. has_bias=False,
  832. )
  833. self.down_proj = RowParallelLinear(
  834. self.intermediate_size,
  835. self.hidden_size,
  836. input_is_parallel=True,
  837. has_bias=False,
  838. )
  839. else:
  840. self.gate_proj = Linear(
  841. self.hidden_size, self.intermediate_size, bias_attr=False
  842. ) # w1
  843. self.up_proj = Linear(
  844. self.hidden_size, self.intermediate_size, bias_attr=False
  845. ) # w3
  846. self.down_proj = Linear(
  847. self.intermediate_size, self.hidden_size, bias_attr=False
  848. ) # w2
  849. self.act_fn = ACT2FN[config.hidden_act]
  850. self.fuse_swiglu = False
  851. def forward(self, x):
  852. x, y = self.gate_proj(x), self.up_proj(x)
  853. if self.fuse_swiglu:
  854. x = self.act_fn(x, y)
  855. else:
  856. x = self.act_fn(x) * y
  857. return self.down_proj(x)
  858. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  859. def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
  860. """
  861. This is the equivalent of paddle.repeat_interleave(x, axis=1, repeats=n_rep). The hidden states go from (batch,
  862. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  863. """
  864. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  865. if n_rep == 1:
  866. return hidden_states
  867. hidden_states = hidden_states[:, :, None, :, :].expand(
  868. [batch, num_key_value_heads, n_rep, slen, head_dim]
  869. )
  870. return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim])
  871. class Qwen2VLAttention(nn.Layer):
  872. """
  873. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  874. and "Generating Long Sequences with Sparse Transformers".
  875. """
  876. def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
  877. super().__init__()
  878. self.config = config
  879. self.layer_idx = layer_idx
  880. if layer_idx is None:
  881. logging.warning_once(
  882. f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
  883. "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  884. "when creating this class."
  885. )
  886. self.hidden_size = config.hidden_size
  887. self.num_heads = config.num_attention_heads
  888. self.head_dim = self.hidden_size // self.num_heads
  889. self.num_key_value_heads = config.num_key_value_heads
  890. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  891. self.max_position_embeddings = config.max_position_embeddings
  892. self.rope_theta = config.rope_theta
  893. self.is_causal = True
  894. self.attention_dropout = config.attention_dropout
  895. self.rope_scaling = config.rope_scaling
  896. # self.sequence_parallel = config.sequence_parallel
  897. if config.tensor_parallel_degree > 1:
  898. assert (
  899. self.num_heads % config.tensor_parallel_degree == 0
  900. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  901. self.num_heads = self.num_heads // config.tensor_parallel_degree
  902. assert (
  903. self.num_key_value_heads % config.tensor_parallel_degree == 0
  904. ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  905. self.num_key_value_heads = (
  906. self.num_key_value_heads // config.tensor_parallel_degree
  907. )
  908. if config.tensor_parallel_degree > 1:
  909. self.q_proj = ColumnParallelLinear(
  910. self.hidden_size, self.hidden_size, has_bias=True, gather_output=False
  911. )
  912. self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  913. self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  914. self.o_proj = RowParallelLinear(
  915. self.hidden_size,
  916. self.hidden_size,
  917. has_bias=False,
  918. input_is_parallel=True,
  919. )
  920. else:
  921. self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
  922. self.k_proj = Linear(
  923. self.hidden_size,
  924. self.config.num_key_value_heads * self.head_dim,
  925. bias_attr=True,
  926. )
  927. self.v_proj = Linear(
  928. self.hidden_size,
  929. self.config.num_key_value_heads * self.head_dim,
  930. bias_attr=True,
  931. )
  932. self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
  933. self.rotary_emb = Qwen2VLRotaryEmbedding(
  934. self.head_dim,
  935. max_position_embeddings=self.max_position_embeddings,
  936. base=self.rope_theta,
  937. )
  938. def forward(
  939. self,
  940. hidden_states: paddle.Tensor,
  941. attention_mask: Optional[paddle.Tensor] = None,
  942. position_ids: Optional[paddle.Tensor] = None,
  943. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  944. output_attentions: bool = False,
  945. use_cache: bool = False, # default true
  946. cache_position: Optional[paddle.Tensor] = None,
  947. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  948. bsz, q_len, _ = hidden_states.shape
  949. try:
  950. query_states = self.q_proj(hidden_states)
  951. key_states = self.k_proj(hidden_states)
  952. value_states = self.v_proj(hidden_states)
  953. except:
  954. hidden_states = hidden_states.astype(self.config.dtype)
  955. query_states = self.q_proj(hidden_states)
  956. key_states = self.k_proj(hidden_states)
  957. value_states = self.v_proj(hidden_states)
  958. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  959. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  960. query_states = query_states.reshape(shape=target_query_shape)
  961. key_states = key_states.reshape(shape=target_key_value_shape)
  962. value_states = value_states.reshape(shape=target_key_value_shape)
  963. new_perm = [0, 2, 1, 3]
  964. query_states = query_states.transpose(new_perm)
  965. key_states = key_states.transpose(new_perm)
  966. value_states = value_states.transpose(new_perm)
  967. kv_seq_len = key_states.shape[-2]
  968. if past_key_value is not None:
  969. kv_seq_len += cache_position[0] + 1
  970. cos, sin = self.rotary_emb(value_states, position_ids)
  971. query_states, key_states = apply_multimodal_rotary_pos_emb(
  972. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  973. )
  974. if past_key_value is not None:
  975. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  976. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  977. past_key_value = (key_states, value_states) if use_cache else None
  978. # repeat k/v heads if n_kv_heads < n_heads
  979. key_states = repeat_kv(key_states, self.num_key_value_groups)
  980. value_states = repeat_kv(value_states, self.num_key_value_groups)
  981. query_states = query_states.astype("float32")
  982. key_states = key_states.astype("float32")
  983. value_states = value_states.astype("float32")
  984. attn_weights = paddle.matmul(
  985. query_states, key_states.transpose([0, 1, 3, 2])
  986. ) / math.sqrt(self.head_dim)
  987. if attention_mask is not None:
  988. attn_weights = attn_weights + attention_mask
  989. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  990. attn_weights = nn.functional.dropout(
  991. x=attn_weights, p=self.attention_dropout, training=self.training
  992. )
  993. attn_output = paddle.matmul(
  994. attn_weights.cast(self.config.dtype), value_states.cast(self.config.dtype)
  995. )
  996. if attn_output.shape != [bsz, self.num_heads, q_len, self.head_dim]:
  997. raise ValueError(
  998. f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
  999. f" {attn_output.shape}"
  1000. )
  1001. attn_output = attn_output.transpose([0, 2, 1, 3])
  1002. attn_output = attn_output.reshape([bsz, q_len, -1])
  1003. if self.o_proj.weight.dtype == paddle.bfloat16:
  1004. attn_output = attn_output.astype(paddle.bfloat16)
  1005. elif self.o_proj.weight.dtype == paddle.float16:
  1006. attn_output = attn_output.astype(paddle.float16)
  1007. elif self.o_proj.weight.dtype == paddle.float32:
  1008. attn_output = attn_output.astype(paddle.float32)
  1009. attn_output = self.o_proj(attn_output)
  1010. if not output_attentions:
  1011. attn_weights = None
  1012. return attn_output, attn_weights, past_key_value
  1013. class Qwen2VLFlashAttention2(Qwen2VLAttention):
  1014. """
  1015. Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
  1016. as the weights of the module stays untouched. The only required change would be on the forward pass
  1017. where it needs to correctly call the public API of flash attention and deal with padding tokens
  1018. in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
  1019. config.max_window_layers layers.
  1020. """
  1021. def __init__(self, *args, **kwargs):
  1022. super().__init__(*args, **kwargs)
  1023. def forward(
  1024. self,
  1025. hidden_states: paddle.Tensor,
  1026. attention_mask: Optional[paddle.Tensor] = None,
  1027. position_ids: Optional[paddle.Tensor] = None,
  1028. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1029. output_attentions: bool = False,
  1030. use_cache: bool = False, # default true
  1031. cache_position: Optional[paddle.Tensor] = None,
  1032. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  1033. bsz, q_len, _ = tuple(hidden_states.shape)
  1034. try:
  1035. query_states = self.q_proj(hidden_states)
  1036. key_states = self.k_proj(hidden_states)
  1037. value_states = self.v_proj(hidden_states)
  1038. except:
  1039. hidden_states = hidden_states.astype("bfloat16")
  1040. query_states = self.q_proj(hidden_states)
  1041. key_states = self.k_proj(hidden_states)
  1042. value_states = self.v_proj(hidden_states)
  1043. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  1044. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  1045. query_states = query_states.reshape(shape=target_query_shape)
  1046. key_states = key_states.reshape(shape=target_key_value_shape)
  1047. value_states = value_states.reshape(shape=target_key_value_shape)
  1048. new_perm = [0, 2, 1, 3]
  1049. query_states = query_states.transpose(new_perm)
  1050. key_states = key_states.transpose(new_perm)
  1051. value_states = value_states.transpose(new_perm)
  1052. kv_seq_len = key_states.shape[-2]
  1053. if past_key_value is not None:
  1054. kv_seq_len += cache_position[0] + 1
  1055. # Because the input can be padded, the absolute sequence length depends on the max position id.
  1056. cos, sin = self.rotary_emb(value_states, position_ids)
  1057. query_states, key_states = apply_multimodal_rotary_pos_emb(
  1058. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  1059. )
  1060. if past_key_value is not None:
  1061. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  1062. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  1063. past_key_value = (key_states, value_states) if use_cache else None
  1064. # repeat k/v heads if n_kv_heads < n_heads
  1065. key_states = repeat_kv(key_states, self.num_key_value_groups)
  1066. value_states = repeat_kv(value_states, self.num_key_value_groups)
  1067. # Reashape to the expected shape for Flash Attention
  1068. # [1, 3599, 12, 128]
  1069. query_states = query_states.transpose(perm=[0, 2, 1, 3])
  1070. key_states = key_states.transpose(perm=[0, 2, 1, 3])
  1071. value_states = value_states.transpose(perm=[0, 2, 1, 3])
  1072. attn_output = self._flash_attention_forward(
  1073. query_states, key_states, value_states, attention_mask, q_len
  1074. )
  1075. attn_output = attn_output.reshape([bsz, q_len, -1])
  1076. attn_output = self.o_proj(attn_output)
  1077. if not output_attentions:
  1078. attn_weights = None
  1079. return attn_output, attn_weights, past_key_value
  1080. def _flash_attention_forward(
  1081. self,
  1082. query_states,
  1083. key_states,
  1084. value_states,
  1085. attention_mask,
  1086. query_length,
  1087. dropout=0.0,
  1088. softmax_scale=None,
  1089. ):
  1090. """
  1091. Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
  1092. first unpad the input, then computes the attention scores and pad the final attention scores.
  1093. Args:
  1094. query_states (`paddle.Tensor`):
  1095. Input query states to be passed to Flash Attention API
  1096. key_states (`paddle.Tensor`):
  1097. Input key states to be passed to Flash Attention API
  1098. value_states (`paddle.Tensor`):
  1099. Input value states to be passed to Flash Attention API
  1100. attention_mask (`paddle.Tensor`):
  1101. The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
  1102. position of padding tokens and 1 for the position of non-padding tokens.
  1103. dropout (`int`, *optional*):
  1104. Attention dropout
  1105. softmax_scale (`float`, *optional*):
  1106. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
  1107. """
  1108. # Contains at least one padding token in the sequence
  1109. causal = self.is_causal and query_length != 1
  1110. if _IS_NPU:
  1111. if attention_mask is not None:
  1112. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1113. query_states,
  1114. key_states,
  1115. value_states,
  1116. attn_mask=attention_mask,
  1117. dropout=dropout,
  1118. causal=causal,
  1119. is_varlen=True,
  1120. )
  1121. else:
  1122. dtype = query_states.dtype
  1123. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1124. query_states.astype("bfloat16"),
  1125. key_states.astype("bfloat16"),
  1126. value_states.astype("bfloat16"),
  1127. attn_mask=attention_mask,
  1128. dropout=dropout,
  1129. causal=causal,
  1130. )
  1131. attn_output = attn_output.astype(dtype)
  1132. else:
  1133. head_dim = query_states.shape[-1]
  1134. softmax_scale = head_dim**-0.5 # TODO: 需要手动加上
  1135. if attention_mask is not None:
  1136. batch_size = query_states.shape[0]
  1137. (
  1138. query_states,
  1139. key_states,
  1140. value_states,
  1141. indices_q,
  1142. cu_seq_lens,
  1143. max_seq_lens,
  1144. ) = self._unpad_input(
  1145. query_states, key_states, value_states, attention_mask, query_length
  1146. )
  1147. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  1148. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  1149. attn_output_unpad = flash_attn_varlen_func(
  1150. query_states,
  1151. key_states,
  1152. value_states,
  1153. cu_seqlens_q=cu_seqlens_q,
  1154. cu_seqlens_k=cu_seqlens_k,
  1155. max_seqlen_q=max_seqlen_in_batch_q,
  1156. max_seqlen_k=max_seqlen_in_batch_k,
  1157. scale=softmax_scale, # not softmax_scale=
  1158. dropout=dropout,
  1159. causal=causal,
  1160. )[0]
  1161. attn_output = pad_input(
  1162. attn_output_unpad, indices_q, batch_size, query_length
  1163. )
  1164. else:
  1165. attn_output = flash_attn_func(
  1166. query_states,
  1167. key_states,
  1168. value_states,
  1169. dropout,
  1170. causal=causal,
  1171. )[0]
  1172. return attn_output
  1173. def _unpad_input(
  1174. self, query_layer, key_layer, value_layer, attention_mask, query_length
  1175. ):
  1176. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
  1177. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  1178. # TODO:cuda error
  1179. key_layer = index_first_axis(
  1180. key_layer.reshape([batch_size * kv_seq_len, num_key_value_heads, head_dim]),
  1181. indices_k,
  1182. )
  1183. value_layer = index_first_axis(
  1184. value_layer.reshape(
  1185. [batch_size * kv_seq_len, num_key_value_heads, head_dim]
  1186. ),
  1187. indices_k,
  1188. )
  1189. if query_length == kv_seq_len:
  1190. query_layer = index_first_axis(
  1191. query_layer.reshape(
  1192. [batch_size * kv_seq_len, self.num_heads, head_dim]
  1193. ),
  1194. indices_k,
  1195. )
  1196. cu_seqlens_q = cu_seqlens_k
  1197. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  1198. indices_q = indices_k
  1199. elif query_length == 1:
  1200. max_seqlen_in_batch_q = 1
  1201. cu_seqlens_q = paddle.arange(
  1202. batch_size + 1, dtype=paddle.int32
  1203. ) # There is a memcpy here, that is very bad.
  1204. indices_q = cu_seqlens_q[:-1]
  1205. query_layer = query_layer.squeeze(1)
  1206. else:
  1207. # The -q_len: slice assumes left padding.
  1208. attention_mask = attention_mask[:, -query_length:]
  1209. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
  1210. query_layer, attention_mask
  1211. )
  1212. return (
  1213. query_layer,
  1214. key_layer,
  1215. value_layer,
  1216. indices_q.to(paddle.int64),
  1217. (cu_seqlens_q, cu_seqlens_k),
  1218. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  1219. )
  1220. class Qwen2VLDecoderLayer(nn.Layer):
  1221. def __init__(self, config: Qwen2VLConfig, layer_idx: int):
  1222. super().__init__()
  1223. self.hidden_size = config.hidden_size
  1224. # use_sliding_window false
  1225. if (
  1226. config.use_sliding_window
  1227. and config.attn_implementation != "flash_attention_2"
  1228. ):
  1229. logging.warning_once(
  1230. f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
  1231. "unexpected results may be encountered."
  1232. )
  1233. self.self_attn = create_attention_module(config, "qwen2vl", layer_idx=layer_idx)
  1234. # self.self_attn = Qwen2VLAttention(config, layer_idx)
  1235. self.mlp = Qwen2MLP(config)
  1236. self.input_layernorm = Qwen2RMSNorm(
  1237. config, config.hidden_size, eps=config.rms_norm_eps
  1238. )
  1239. self.post_attention_layernorm = Qwen2RMSNorm(
  1240. config, config.hidden_size, eps=config.rms_norm_eps
  1241. )
  1242. def forward(
  1243. self,
  1244. hidden_states: paddle.Tensor,
  1245. attention_mask: Optional[paddle.Tensor] = None,
  1246. position_ids: Optional[paddle.Tensor] = None,
  1247. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1248. output_attentions: Optional[bool] = False,
  1249. use_cache: Optional[bool] = False,
  1250. cache_position: Optional[paddle.Tensor] = None,
  1251. **kwargs,
  1252. ):
  1253. """
  1254. Args:
  1255. hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  1256. attention_mask (`paddle.Tensor`, *optional*): attention mask of size
  1257. `(batch, sequence_length)` where padding elements are indicated by 0.
  1258. output_attentions (`bool`, *optional*):
  1259. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1260. returned tensors for more detail.
  1261. use_cache (`bool`, *optional*):
  1262. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  1263. (see `past_key_values`).
  1264. past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
  1265. cache_position (`paddle.Tensor` of shape `(sequence_length)`, *optional*):
  1266. Indices depicting the position of the input sequence tokens in the sequence.
  1267. kwargs (`dict`, *optional*):
  1268. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  1269. into the model
  1270. """
  1271. residual = hidden_states
  1272. hidden_states = self.input_layernorm(hidden_states)
  1273. # Self Attention
  1274. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  1275. hidden_states=hidden_states,
  1276. attention_mask=attention_mask,
  1277. position_ids=position_ids,
  1278. past_key_value=past_key_value,
  1279. output_attentions=output_attentions,
  1280. use_cache=use_cache,
  1281. cache_position=cache_position,
  1282. )
  1283. hidden_states = residual + hidden_states
  1284. # Fully Connected
  1285. residual = hidden_states
  1286. hidden_states = self.post_attention_layernorm(hidden_states)
  1287. hidden_states = self.mlp(hidden_states)
  1288. hidden_states = residual + hidden_states
  1289. outputs = (hidden_states,)
  1290. if output_attentions:
  1291. outputs += (self_attn_weights,)
  1292. if use_cache:
  1293. outputs += (present_key_value,)
  1294. return outputs
  1295. class Qwen2VLPreTrainedModel(PretrainedModel):
  1296. config_class = Qwen2VLConfig
  1297. base_model_prefix = "model"
  1298. _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
  1299. _skip_keys_device_placement = "past_key_values"
  1300. def _init_weights(self, layer):
  1301. std = 0.2
  1302. if isinstance(layer, (nn.Linear, nn.Conv3D)):
  1303. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1304. if layer.bias is not None:
  1305. nn.initializer.Constant(0.0)(layer.bias)
  1306. elif isinstance(layer, nn.Embedding):
  1307. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1308. if layer._padding_idx is not None:
  1309. with paddle.no_grad():
  1310. layer.weight[layer._padding_idx] = 0.0
  1311. class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
  1312. config_class = Qwen2VLVisionConfig
  1313. _no_split_modules = ["Qwen2VLVisionBlock"]
  1314. def __init__(self, config) -> None:
  1315. super().__init__(config)
  1316. self.spatial_merge_size = config.spatial_merge_size
  1317. self.patch_embed = PatchEmbed(
  1318. patch_size=config.patch_size,
  1319. temporal_patch_size=config.temporal_patch_size,
  1320. in_channels=config.in_channels,
  1321. embed_dim=config.embed_dim,
  1322. )
  1323. head_dim = config.embed_dim // config.num_heads
  1324. self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
  1325. self.blocks = nn.LayerList(
  1326. [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
  1327. )
  1328. self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
  1329. self.enable_recompute = False
  1330. def get_dtype(self) -> paddle.dtype:
  1331. return self.blocks[0].mlp.fc2.weight.dtype
  1332. def rot_pos_emb(self, grid_thw):
  1333. pos_ids = []
  1334. for t, h, w in grid_thw:
  1335. hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w])
  1336. hpos_ids = hpos_ids.reshape(
  1337. [
  1338. h // self.spatial_merge_size,
  1339. self.spatial_merge_size,
  1340. w // self.spatial_merge_size,
  1341. self.spatial_merge_size,
  1342. ]
  1343. )
  1344. hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3])
  1345. hpos_ids = hpos_ids.flatten()
  1346. wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1])
  1347. wpos_ids = wpos_ids.reshape(
  1348. [
  1349. h // self.spatial_merge_size,
  1350. self.spatial_merge_size,
  1351. w // self.spatial_merge_size,
  1352. self.spatial_merge_size,
  1353. ]
  1354. )
  1355. wpos_ids = wpos_ids.transpose([0, 2, 1, 3])
  1356. wpos_ids = wpos_ids.flatten()
  1357. pos_ids.append(
  1358. paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1])
  1359. )
  1360. pos_ids = paddle.concat(x=pos_ids, axis=0)
  1361. max_grid_size = grid_thw[:, 1:].max()
  1362. rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
  1363. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
  1364. return rotary_pos_emb
  1365. @paddle.jit.not_to_static
  1366. def recompute_training_full(
  1367. self,
  1368. layer_module: nn.Layer,
  1369. hidden_states: paddle.Tensor,
  1370. cu_seqlens_now: paddle.Tensor,
  1371. rotary_pos_emb: paddle.Tensor,
  1372. ):
  1373. def create_custom_forward(module):
  1374. def custom_forward(*inputs):
  1375. return module(*inputs)
  1376. return custom_forward
  1377. hidden_states = recompute(
  1378. create_custom_forward(layer_module),
  1379. hidden_states,
  1380. cu_seqlens_now,
  1381. rotary_pos_emb,
  1382. # use_reentrant=self.config.recompute_use_reentrant,
  1383. )
  1384. return hidden_states
  1385. def forward(
  1386. self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor
  1387. ) -> paddle.Tensor:
  1388. # breakpoint()
  1389. hidden_states = self.patch_embed(hidden_states)
  1390. rotary_pos_emb = self.rot_pos_emb(grid_thw)
  1391. cu_seqlens = paddle.repeat_interleave(
  1392. grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
  1393. ).cumsum(axis=0, dtype="int32")
  1394. cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
  1395. for idx, blk in enumerate(self.blocks):
  1396. if self.enable_recompute and self.training:
  1397. hidden_states = self.recompute_training_full(
  1398. blk, hidden_states, cu_seqlens, rotary_pos_emb
  1399. )
  1400. else:
  1401. hidden_states = blk(
  1402. hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
  1403. )
  1404. return self.merger(hidden_states)
  1405. class Qwen2VLModel(Qwen2VLPreTrainedModel):
  1406. def __init__(self, config: Qwen2VLConfig):
  1407. super().__init__(config)
  1408. self.padding_idx = config.pad_token_id
  1409. self.vocab_size = config.vocab_size
  1410. self.hidden_size = config.hidden_size
  1411. # Recompute defaults to False and is controlled by Trainer
  1412. if (
  1413. config.tensor_parallel_degree > 1
  1414. and config.vocab_size % config.tensor_parallel_degree == 0
  1415. ):
  1416. self.embed_tokens = mpu.VocabParallelEmbedding(
  1417. self.vocab_size,
  1418. self.hidden_size,
  1419. weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
  1420. )
  1421. else:
  1422. self.embed_tokens = nn.Embedding(
  1423. self.vocab_size,
  1424. self.hidden_size,
  1425. )
  1426. # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  1427. self.layers = nn.LayerList(
  1428. [
  1429. Qwen2VLDecoderLayer(config, layer_idx)
  1430. for layer_idx in range(config.num_hidden_layers)
  1431. ]
  1432. )
  1433. self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)
  1434. self.enamble_recompute = False
  1435. def get_input_embeddings(self):
  1436. return self.embed_tokens
  1437. def set_input_embeddings(self, value):
  1438. self.embed_tokens = value
  1439. @staticmethod
  1440. def _prepare_decoder_attention_mask(
  1441. attention_mask, input_shape, past_key_values_length, dtype
  1442. ):
  1443. if attention_mask is not None:
  1444. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1445. if len(attention_mask.shape) == 2:
  1446. expanded_attn_mask = _expand_2d_mask(
  1447. attention_mask, dtype, tgt_length=input_shape[-1]
  1448. )
  1449. # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
  1450. if input_shape[-1] > 1:
  1451. combined_attention_mask = _make_causal_mask(
  1452. input_shape,
  1453. past_key_values_length=past_key_values_length,
  1454. )
  1455. expanded_attn_mask = expanded_attn_mask & combined_attention_mask
  1456. # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
  1457. elif len(attention_mask.shape) == 3:
  1458. expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
  1459. # if attention_mask is already 4-D, do nothing
  1460. else:
  1461. expanded_attn_mask = attention_mask
  1462. else:
  1463. expanded_attn_mask = _make_causal_mask(
  1464. input_shape,
  1465. past_key_values_length=past_key_values_length,
  1466. )
  1467. # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
  1468. expanded_attn_mask = paddle.where(
  1469. expanded_attn_mask, 0.0, paddle.finfo(dtype).min
  1470. ).astype(dtype)
  1471. return expanded_attn_mask
  1472. @paddle.jit.not_to_static
  1473. def recompute_training_full(
  1474. self,
  1475. layer_module: nn.Layer,
  1476. hidden_states: paddle.Tensor,
  1477. attention_mask: paddle.Tensor,
  1478. position_ids: Optional[paddle.Tensor],
  1479. past_key_value: paddle.Tensor,
  1480. output_attentions: bool,
  1481. use_cache: bool,
  1482. cache_position: Optional[paddle.Tensor] = None,
  1483. ):
  1484. def create_custom_forward(module):
  1485. def custom_forward(*inputs):
  1486. return module(*inputs)
  1487. return custom_forward
  1488. hidden_states = recompute(
  1489. create_custom_forward(layer_module),
  1490. hidden_states,
  1491. attention_mask,
  1492. position_ids,
  1493. past_key_value,
  1494. output_attentions,
  1495. use_cache,
  1496. cache_position,
  1497. use_reentrant=self.config.recompute_use_reentrant,
  1498. )
  1499. return hidden_states
  1500. def forward(
  1501. self,
  1502. input_ids: paddle.Tensor = None,
  1503. attention_mask: Optional[paddle.Tensor] = None,
  1504. position_ids: Optional[paddle.Tensor] = None,
  1505. past_key_values: Optional[List[paddle.Tensor]] = None,
  1506. inputs_embeds: Optional[paddle.Tensor] = None,
  1507. use_cache: Optional[bool] = None,
  1508. output_attentions: Optional[bool] = None,
  1509. output_hidden_states: Optional[bool] = None,
  1510. return_dict: Optional[bool] = None,
  1511. cache_position: Optional[paddle.Tensor] = None,
  1512. ) -> Union[Tuple, BaseModelOutputWithPast]:
  1513. output_attentions = (
  1514. output_attentions
  1515. if output_attentions is not None
  1516. else self.config.output_attentions
  1517. )
  1518. output_hidden_states = (
  1519. output_hidden_states
  1520. if output_hidden_states is not None
  1521. else self.config.output_hidden_states
  1522. )
  1523. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1524. return_dict = (
  1525. return_dict if return_dict is not None else self.config.use_return_dict
  1526. )
  1527. if (input_ids is None) ^ (inputs_embeds is not None):
  1528. raise ValueError(
  1529. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  1530. )
  1531. elif input_ids is not None:
  1532. batch_size, seq_length = input_ids.shape
  1533. elif inputs_embeds is not None:
  1534. batch_size, seq_length, _ = inputs_embeds.shape
  1535. else:
  1536. raise ValueError(
  1537. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1538. )
  1539. if past_key_values is None:
  1540. past_key_values = tuple([None] * len(self.layers))
  1541. # NOTE: to make cache can be clear in-time
  1542. past_key_values = list(past_key_values)
  1543. seq_length_with_past = seq_length
  1544. cache_length = 0
  1545. if past_key_values[0] is not None:
  1546. cache_length = past_key_values[0][0].shape[2] # shape[1] in qwen2
  1547. seq_length_with_past += cache_length
  1548. if inputs_embeds is None:
  1549. inputs_embeds = self.embed_tokens(input_ids)
  1550. # embed positions
  1551. if attention_mask is None:
  1552. # [bs, seq_len]
  1553. attention_mask = paddle.ones(
  1554. (batch_size, seq_length_with_past), dtype=paddle.bool
  1555. )
  1556. if flash_attn_varlen_func:
  1557. causal_mask = attention_mask
  1558. else:
  1559. causal_mask = self._prepare_decoder_attention_mask(
  1560. attention_mask,
  1561. (batch_size, seq_length),
  1562. cache_length,
  1563. inputs_embeds.dtype,
  1564. ) # [bs, 1, seq_len, seq_len]
  1565. if cache_position is None:
  1566. past_seen_tokens = (
  1567. past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
  1568. )
  1569. cache_position = paddle.arange(
  1570. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
  1571. )
  1572. if position_ids is None:
  1573. # the hard coded `3` is for temporal, height and width.
  1574. position_ids = cache_position.reshape([1, 1, -1]).expand(
  1575. [3, inputs_embeds.shape[0], -1]
  1576. )
  1577. hidden_states = inputs_embeds
  1578. # decoder layers
  1579. all_hidden_states = () if output_hidden_states else None
  1580. all_self_attns = () if output_attentions else None
  1581. next_decoder_cache = ()
  1582. for idx, (decoder_layer) in enumerate(self.layers):
  1583. if output_hidden_states:
  1584. all_hidden_states += (hidden_states,)
  1585. past_key_value = (
  1586. past_key_values[idx] if past_key_values is not None else None
  1587. )
  1588. if self.enamble_recompute and self.training:
  1589. layer_outputs = self.recompute_training_full(
  1590. decoder_layer,
  1591. hidden_states,
  1592. causal_mask,
  1593. position_ids,
  1594. past_key_value,
  1595. output_attentions,
  1596. use_cache,
  1597. cache_position,
  1598. )
  1599. else:
  1600. layer_outputs = decoder_layer(
  1601. hidden_states,
  1602. attention_mask=causal_mask,
  1603. position_ids=position_ids,
  1604. past_key_value=past_key_value,
  1605. output_attentions=output_attentions, # False
  1606. use_cache=use_cache, # True
  1607. cache_position=cache_position,
  1608. )
  1609. # NOTE: clear outdate cache after it has been used for memory saving
  1610. past_key_value = past_key_values[idx] = None
  1611. hidden_states = layer_outputs[0]
  1612. next_decoder_cache = (
  1613. next_decoder_cache + (layer_outputs[-1],) if use_cache else None
  1614. )
  1615. if output_attentions:
  1616. all_self_attns += (layer_outputs[1],)
  1617. hidden_states = self.norm(hidden_states)
  1618. # add hidden states from the last decoder layer
  1619. if output_hidden_states:
  1620. all_hidden_states += (hidden_states,)
  1621. next_cache = next_decoder_cache if use_cache else None
  1622. if not return_dict:
  1623. return tuple(
  1624. v
  1625. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
  1626. if v is not None
  1627. )
  1628. return BaseModelOutputWithPast(
  1629. last_hidden_state=hidden_states,
  1630. past_key_values=next_cache,
  1631. hidden_states=all_hidden_states,
  1632. attentions=all_self_attns,
  1633. )
  1634. class Qwen2LMHead(nn.Layer):
  1635. def __init__(self, config, embedding_weights=None, transpose_y=False):
  1636. super(Qwen2LMHead, self).__init__()
  1637. self.config = config
  1638. if (
  1639. config.tensor_parallel_degree > 1
  1640. and config.vocab_size % config.tensor_parallel_degree == 0
  1641. ):
  1642. vocab_size = config.vocab_size // config.tensor_parallel_degree
  1643. else:
  1644. vocab_size = config.vocab_size
  1645. self.transpose_y = transpose_y
  1646. if transpose_y:
  1647. # only for weight from embedding_weights
  1648. if embedding_weights is not None:
  1649. self.weight = embedding_weights
  1650. else:
  1651. self.weight = self.create_parameter(
  1652. shape=[vocab_size, config.hidden_size],
  1653. dtype=paddle.get_default_dtype(),
  1654. )
  1655. else:
  1656. if vocab_size != config.vocab_size:
  1657. with get_rng_state_tracker().rng_state():
  1658. self.weight = self.create_parameter(
  1659. shape=[config.hidden_size, vocab_size],
  1660. dtype=paddle.get_default_dtype(),
  1661. )
  1662. else:
  1663. self.weight = self.create_parameter(
  1664. shape=[config.hidden_size, vocab_size],
  1665. dtype=paddle.get_default_dtype(),
  1666. )
  1667. # Must set distributed attr for Tensor Parallel !
  1668. self.weight.is_distributed = (
  1669. True if (vocab_size != config.vocab_size) else False
  1670. )
  1671. if self.weight.is_distributed:
  1672. # for tie_word_embeddings
  1673. self.weight.split_axis = 0 if self.transpose_y else 1
  1674. def forward(self, hidden_states, tensor_parallel_output=None):
  1675. if tensor_parallel_output is None:
  1676. tensor_parallel_output = self.config.tensor_parallel_output
  1677. # 确保数据类型一致
  1678. if self.weight.dtype != hidden_states.dtype:
  1679. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  1680. logits = parallel_matmul(
  1681. hidden_states,
  1682. self.weight,
  1683. transpose_y=self.transpose_y,
  1684. tensor_parallel_output=tensor_parallel_output,
  1685. )
  1686. return logits
  1687. class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
  1688. _tied_weights_keys = ["lm_head.weight"]
  1689. def __init__(self, config, attn_implementation="flash_attention_2"):
  1690. super().__init__(config)
  1691. config._attn_implementation = attn_implementation
  1692. config.vision_config._attn_implementation = attn_implementation
  1693. self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
  1694. config.vision_config
  1695. )
  1696. self.model = Qwen2VLModel(config)
  1697. self.vocab_size = config.vocab_size
  1698. if config.tie_word_embeddings:
  1699. self.lm_head = Qwen2LMHead(
  1700. config,
  1701. embedding_weights=self.model.embed_tokens.weight,
  1702. transpose_y=True,
  1703. )
  1704. self.tie_weights()
  1705. else:
  1706. self.lm_head = Qwen2LMHead(config)
  1707. self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
  1708. def get_input_embeddings(self):
  1709. return self.model.embed_tokens
  1710. def set_input_embeddings(self, value):
  1711. self.model.embed_tokens = value
  1712. def get_output_embeddings(self):
  1713. return self.lm_head
  1714. def set_output_embeddings(self, new_embeddings):
  1715. self.lm_head = new_embeddings
  1716. def set_decoder(self, decoder):
  1717. self.model = decoder
  1718. def get_decoder(self):
  1719. return self.model
  1720. @staticmethod
  1721. def get_rope_index(
  1722. spatial_merge_size,
  1723. image_token_id,
  1724. video_token_id,
  1725. vision_start_token_id,
  1726. input_ids: paddle.Tensor,
  1727. image_grid_thw: Optional[paddle.Tensor] = None,
  1728. video_grid_thw: Optional[paddle.Tensor] = None,
  1729. attention_mask: Optional[paddle.Tensor] = None,
  1730. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  1731. """
  1732. Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
  1733. Explanation:
  1734. Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
  1735. For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
  1736. Examples:
  1737. input_ids: [T T T T T], here T is for text.
  1738. temporal position_ids: [0, 1, 2, 3, 4]
  1739. height position_ids: [0, 1, 2, 3, 4]
  1740. width position_ids: [0, 1, 2, 3, 4]
  1741. For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
  1742. and 1D rotary position embedding for text part.
  1743. Examples:
  1744. Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
  1745. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
  1746. vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
  1747. vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
  1748. vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
  1749. text temporal position_ids: [3, 4, 5, 6, 7]
  1750. text height position_ids: [3, 4, 5, 6, 7]
  1751. text width position_ids: [3, 4, 5, 6, 7]
  1752. Here we calculate the text start position_ids as the max vision position_ids plus 1.
  1753. Args:
  1754. input_ids (`paddle.Tensor` of shape `(batch_size, sequence_length)`):
  1755. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  1756. it.
  1757. image_grid_thw (`paddle.Tensor` of shape `(num_images, 3)`, *optional*):
  1758. The temporal, height and width of feature shape of each image in LLM.
  1759. video_grid_thw (`paddle.Tensor` of shape `(num_videos, 3)`, *optional*):
  1760. The temporal, height and width of feature shape of each video in LLM.
  1761. attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1762. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1763. - 1 for tokens that are **not masked**,
  1764. - 0 for tokens that are **masked**.
  1765. Returns:
  1766. position_ids (`paddle.Tensor` of shape `(3, batch_size, sequence_length)`)
  1767. mrope_position_deltas (`paddle.Tensor` of shape `(batch_size)`)
  1768. """
  1769. mrope_position_deltas = []
  1770. if image_grid_thw is not None or video_grid_thw is not None:
  1771. total_input_ids = input_ids
  1772. position_ids = paddle.ones(
  1773. [3, input_ids.shape[0], input_ids.shape[1]], dtype=input_ids.dtype
  1774. )
  1775. image_index, video_index = 0, 0
  1776. for i, input_ids in enumerate(total_input_ids):
  1777. # TODO: CUDA error in some paddle version
  1778. if attention_mask is not None:
  1779. input_ids = paddle.to_tensor(
  1780. input_ids.cpu()[attention_mask[i].cpu() == 1]
  1781. ) # NOTE 原始写法
  1782. image_nums, video_nums = 0, 0
  1783. vision_start_indices = paddle.nonzero(
  1784. input_ids == vision_start_token_id
  1785. ).squeeze(
  1786. 1
  1787. ) # NOTE 原始写法
  1788. vision_tokens = input_ids[vision_start_indices + 1]
  1789. image_nums = (
  1790. (vision_tokens == image_token_id).sum()
  1791. if vision_tokens.numel() > 0
  1792. else 0
  1793. )
  1794. video_nums = (
  1795. (vision_tokens == video_token_id).sum()
  1796. if vision_tokens.numel() > 0
  1797. else 0
  1798. )
  1799. input_tokens = input_ids.tolist()
  1800. llm_pos_ids_list: list = []
  1801. st = 0
  1802. remain_images, remain_videos = image_nums, video_nums
  1803. for _ in range(image_nums + video_nums):
  1804. if image_token_id in input_tokens and remain_images > 0:
  1805. ed_image = input_tokens.index(image_token_id, st)
  1806. else:
  1807. ed_image = len(input_tokens) + 1
  1808. if video_token_id in input_tokens and remain_videos > 0:
  1809. ed_video = input_tokens.index(video_token_id, st)
  1810. else:
  1811. ed_video = len(input_tokens) + 1
  1812. if ed_image < ed_video:
  1813. t, h, w = (
  1814. image_grid_thw[image_index][0],
  1815. image_grid_thw[image_index][1],
  1816. image_grid_thw[image_index][2],
  1817. )
  1818. image_index += 1
  1819. remain_images -= 1
  1820. ed = ed_image
  1821. else:
  1822. t, h, w = (
  1823. video_grid_thw[video_index][0],
  1824. video_grid_thw[video_index][1],
  1825. video_grid_thw[video_index][2],
  1826. )
  1827. video_index += 1
  1828. remain_videos -= 1
  1829. ed = ed_video
  1830. llm_grid_t, llm_grid_h, llm_grid_w = (
  1831. t.item(),
  1832. h.item() // spatial_merge_size,
  1833. w.item() // spatial_merge_size,
  1834. )
  1835. text_len = ed - st
  1836. st_idx = (
  1837. llm_pos_ids_list[-1].max() + 1
  1838. if len(llm_pos_ids_list) > 0
  1839. else 0
  1840. )
  1841. llm_pos_ids_list.append(
  1842. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1843. + st_idx
  1844. )
  1845. t_index = (
  1846. paddle.arange(llm_grid_t)
  1847. .reshape([-1, 1])
  1848. .expand([-1, llm_grid_h * llm_grid_w])
  1849. .flatten()
  1850. )
  1851. h_index = (
  1852. paddle.arange(llm_grid_h)
  1853. .reshape([1, -1, 1])
  1854. .expand([llm_grid_t, -1, llm_grid_w])
  1855. .flatten()
  1856. )
  1857. w_index = (
  1858. paddle.arange(llm_grid_w)
  1859. .reshape([1, 1, -1])
  1860. .expand([llm_grid_t, llm_grid_h, -1])
  1861. .flatten()
  1862. )
  1863. llm_pos_ids_list.append(
  1864. paddle.stack([t_index, h_index, w_index]) + text_len + st_idx
  1865. )
  1866. st = ed + llm_grid_t * llm_grid_h * llm_grid_w
  1867. if st < len(input_tokens):
  1868. st_idx = (
  1869. llm_pos_ids_list[-1].max() + 1
  1870. if len(llm_pos_ids_list) > 0
  1871. else 0
  1872. )
  1873. text_len = len(input_tokens) - st
  1874. llm_pos_ids_list.append(
  1875. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1876. + st_idx
  1877. )
  1878. llm_positions = paddle.concat(llm_pos_ids_list, axis=1).reshape([3, -1])
  1879. if _IS_NPU:
  1880. bool_indices = (
  1881. (attention_mask[i] == 1)
  1882. .unsqueeze(0)
  1883. .tile([position_ids.shape[0], 1])
  1884. )
  1885. position_ids[:, i] = paddle.index_put(
  1886. position_ids[:, i], [bool_indices], llm_positions.reshape([-1])
  1887. )
  1888. else:
  1889. position_ids[..., i, attention_mask[i] == 1] = llm_positions
  1890. mrope_position_deltas.append(
  1891. llm_positions.max() + 1 - len(total_input_ids[i])
  1892. )
  1893. mrope_position_deltas = paddle.to_tensor(mrope_position_deltas).unsqueeze(1)
  1894. else:
  1895. if attention_mask is not None:
  1896. position_ids = paddle.cast(attention_mask, dtype="int64").cumsum(-1) - 1
  1897. position_ids.masked_fill_(mask=attention_mask == 0, value=1)
  1898. position_ids = position_ids.unsqueeze(0).expand([3, -1, -1])
  1899. max_position_ids = position_ids.max(0, keepdim=False)[0].max(
  1900. -1, keepdim=True
  1901. )[0]
  1902. mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  1903. else:
  1904. position_ids = (
  1905. paddle.arange(input_ids.shape[1])
  1906. .reshape([1, 1, -1])
  1907. .expand(shape=[3, input_ids.shape[0], -1])
  1908. )
  1909. mrope_position_deltas = paddle.zeros(
  1910. [input_ids.shape[0], 1], dtype=input_ids.dtype
  1911. )
  1912. return position_ids, mrope_position_deltas
  1913. def update_model_kwargs_for_generation(
  1914. self,
  1915. outputs: ModelOutput,
  1916. model_kwargs: Dict[str, Any],
  1917. is_encoder_decoder: bool = False,
  1918. # num_new_tokens: int = 1,
  1919. ) -> Dict[str, Any]:
  1920. model_kwargs = super().update_model_kwargs_for_generation(
  1921. outputs=outputs,
  1922. model_kwargs=model_kwargs,
  1923. is_encoder_decoder=is_encoder_decoder,
  1924. )
  1925. if getattr(outputs, "rope_deltas", None) is not None:
  1926. model_kwargs["rope_deltas"] = outputs.rope_deltas
  1927. return model_kwargs
  1928. def forward(
  1929. self,
  1930. input_ids: paddle.Tensor = None,
  1931. attention_mask: Optional[paddle.Tensor] = None,
  1932. position_ids: Optional[paddle.Tensor] = None,
  1933. past_key_values: Optional[List[paddle.Tensor]] = None,
  1934. inputs_embeds: Optional[paddle.Tensor] = None,
  1935. labels: Optional[paddle.Tensor] = None,
  1936. use_cache: Optional[bool] = None,
  1937. output_attentions: Optional[bool] = None,
  1938. output_hidden_states: Optional[bool] = None,
  1939. return_dict: Optional[bool] = None,
  1940. pixel_values: Optional[paddle.Tensor] = None,
  1941. pixel_values_videos: Optional[paddle.Tensor] = None,
  1942. image_grid_thw: Optional[paddle.Tensor] = None,
  1943. video_grid_thw: Optional[paddle.Tensor] = None,
  1944. rope_deltas: Optional[paddle.Tensor] = None,
  1945. ):
  1946. """
  1947. Args:
  1948. labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1949. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1950. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1951. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1952. """
  1953. output_attentions = (
  1954. output_attentions
  1955. if output_attentions is not None
  1956. else self.config.output_attentions
  1957. )
  1958. output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
  1959. return_dict = True # return_dict if return_dict is not None else self.config.use_return_dict
  1960. if inputs_embeds is None:
  1961. inputs_embeds = self.model.embed_tokens(input_ids)
  1962. if pixel_values is not None:
  1963. pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
  1964. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  1965. image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
  1966. image_mask = input_ids == self.config.image_token_id
  1967. if self.training:
  1968. inputs_embeds = inputs_embeds.clone()
  1969. inputs_embeds[image_mask] = image_embeds
  1970. if pixel_values_videos is not None:
  1971. pixel_values_videos = paddle.cast(
  1972. pixel_values_videos, inputs_embeds.dtype
  1973. )
  1974. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  1975. video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
  1976. video_mask = input_ids == self.config.video_token_id
  1977. inputs_embeds[video_mask] = video_embeds
  1978. if attention_mask is not None:
  1979. attention_mask = attention_mask
  1980. outputs = self.model(
  1981. input_ids=None,
  1982. position_ids=position_ids,
  1983. attention_mask=attention_mask,
  1984. past_key_values=past_key_values,
  1985. inputs_embeds=inputs_embeds,
  1986. use_cache=use_cache,
  1987. output_attentions=output_attentions,
  1988. output_hidden_states=output_hidden_states,
  1989. return_dict=return_dict,
  1990. )
  1991. hidden_states = outputs[0]
  1992. tensor_parallel_output = (
  1993. self.config.tensor_parallel_output
  1994. and self.config.tensor_parallel_degree > 1
  1995. )
  1996. logits = self.lm_head(
  1997. hidden_states, tensor_parallel_output=tensor_parallel_output
  1998. )
  1999. logits = paddle.cast(logits, "float32")
  2000. loss = None
  2001. if labels is not None:
  2002. # Shift so that tokens < n predict n
  2003. shift_logits = logits[..., :-1, :]
  2004. shift_labels = labels[..., 1:]
  2005. # Flatten the tokens
  2006. shift_logits = shift_logits.reshape([-1, self.config.vocab_size])
  2007. shift_labels = shift_labels.reshape([-1])
  2008. if _IS_NPU:
  2009. tmp = F.log_softmax(shift_logits, axis=1)
  2010. loss = F.nll_loss(tmp, shift_labels, reduction="sum")
  2011. else:
  2012. loss_fct = nn.CrossEntropyLoss(reduction="sum")
  2013. loss = loss_fct(shift_logits, shift_labels)
  2014. label_sum = paddle.sum(shift_labels != -100).cast("float32")
  2015. loss = loss / label_sum
  2016. if not return_dict:
  2017. output = (logits,) + tuple(outputs[1:])
  2018. return (loss,) + output if loss is not None else output
  2019. return Qwen2VLCausalLMOutputWithPast(
  2020. loss=loss,
  2021. logits=logits,
  2022. past_key_values=outputs.past_key_values,
  2023. hidden_states=outputs.hidden_states,
  2024. attentions=outputs.attentions,
  2025. rope_deltas=rope_deltas,
  2026. )
  def prepare_inputs_for_generation(
      self,
      input_ids,
      past_key_values=None,
      attention_mask=None,
      inputs_embeds=None,
      cache_position=None,
      position_ids=None,
      use_cache=True,
      pixel_values=None,
      pixel_values_videos=None,
      image_grid_thw=None,
      video_grid_thw=None,
      **kwargs,
  ):
      """Build the keyword arguments for a single decoding step.

      When ``past_key_values`` is None the whole ``input_ids`` sequence is fed.
      Otherwise only its last token is fed.

      If an ``attention_mask`` is given but ``position_ids`` is not:
        * on the first step, position ids and ``rope_deltas`` come from
          ``self.get_rope_index``;
        * on later steps, they are built from ``arange`` plus the cached
          ``rope_deltas`` read from ``kwargs``.

      Image/video pixel inputs are dropped after the first step.

      Note:
          The ``cache_position`` argument is not used. It is always
          recomputed from ``input_ids`` and ``past_key_values``.

      Returns:
          dict: ``inputs_embeds`` (first step only, when provided) or
          ``input_ids``, followed by position_ids, past_key_values, use_cache,
          attention_mask, pixel_values, pixel_values_videos, image_grid_thw,
          video_grid_thw and rope_deltas.
      """
      _, total_length = input_ids.shape
      if past_key_values is None:
          cache_position = paddle.arange(total_length)
      else:
          cache_position = paddle.to_tensor([total_length - 1])
          # The KV cache already covers the prefix; feed only the newest token.
          input_ids = input_ids[:, -1].unsqueeze(-1)

      rope_deltas = kwargs.get("rope_deltas", None)
      # cache_position is always set above, so this tensor comparison is the
      # single "first step" test used below (it is 0 when past_key_values is None).
      first_step = cache_position[0] == 0

      if attention_mask is not None and position_ids is None:
          if first_step:
              position_ids, rope_deltas = self.get_rope_index(
                  self.config.vision_config.spatial_merge_size,
                  self.config.image_token_id,
                  self.config.video_token_id,
                  self.config.vision_start_token_id,
                  input_ids,
                  image_grid_thw,
                  video_grid_thw,
                  attention_mask,
              )
          else:
              batch_size, step_length = input_ids.shape
              offset = 0 if rope_deltas is None else cache_position[0] + rope_deltas
              step_positions = (
                  paddle.arange(step_length)
                  .reshape([1, -1])
                  .expand([batch_size, -1])
              )
              # Shift by the rope offset, then replicate into 3 identical rows.
              position_ids = (
                  (step_positions + offset)
                  .unsqueeze(axis=0)
                  .expand([3, -1, -1])
              )

      if not first_step:
          # Visual inputs have already been merged on the first step.
          pixel_values = None
          pixel_values_videos = None

      # Precomputed embeddings are only consumed on the first generation step.
      use_embeds = inputs_embeds is not None and first_step
      model_inputs = (
          {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
      )
      model_inputs.update(
          position_ids=position_ids,
          past_key_values=past_key_values,
          use_cache=use_cache,
          attention_mask=attention_mask,
          pixel_values=pixel_values,
          pixel_values_videos=pixel_values_videos,
          image_grid_thw=image_grid_thw,
          video_grid_thw=video_grid_thw,
          rope_deltas=rope_deltas,
      )
      return model_inputs
  def gme_qwen2_vl_forward(
      self,
      input_ids: paddle.Tensor = None,
      attention_mask: Optional[paddle.Tensor] = None,
      position_ids: Optional[paddle.Tensor] = None,
      past_key_values: Optional[List[paddle.Tensor]] = None,
      inputs_embeds: Optional[paddle.Tensor] = None,
      labels: Optional[paddle.Tensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      pixel_values: Optional[paddle.Tensor] = None,
      pixel_values_videos: Optional[paddle.Tensor] = None,
      image_grid_thw: Optional[paddle.Tensor] = None,
      video_grid_thw: Optional[paddle.Tensor] = None,
      rope_deltas: Optional[paddle.Tensor] = None,
  ):
      """Run the decoder and return the hidden state at the last sequence position.

      This builds the same (optionally vision-augmented) input embeddings as a
      regular forward pass and calls ``self.model``. It returns
      ``outputs[0][:, -1, :]`` instead of logits or a loss.

      Args:
          input_ids: Token ids. They are embedded when ``inputs_embeds`` is
              None, and they are compared against ``config.image_token_id`` /
              ``config.video_token_id`` to find placeholder positions.
          attention_mask, position_ids, past_key_values, use_cache: Passed
              through to ``self.model`` unchanged.
          inputs_embeds: Precomputed input embeddings. When given, the
              pixel inputs are not encoded or merged.
          labels: Accepted for signature compatibility; not used.
          output_attentions, output_hidden_states: Default to the matching
              ``self.config`` flags when None.
          return_dict: Ignored; ``self.model`` is always called with
              ``return_dict=True``.
          pixel_values, image_grid_thw: Image inputs encoded by ``self.visual``.
          pixel_values_videos, video_grid_thw: Video inputs encoded by
              ``self.visual``.
          rope_deltas: Accepted for signature compatibility; not used.

      Returns:
          paddle.Tensor: ``outputs[0][:, -1, :]``. Presumably one
          ``[batch, hidden]`` vector per row, assuming ``outputs[0]`` is
          ``[batch, seq, hidden]``; TODO confirm.
      """
      output_attentions = (
          output_attentions
          if output_attentions is not None
          else self.config.output_attentions
      )
      output_hidden_states = (
          output_hidden_states
          if output_hidden_states is not None
          else self.config.output_hidden_states
      )
      # The caller's return_dict is deliberately overridden; only outputs[0] is read below.
      return_dict = True
      if inputs_embeds is None:
          inputs_embeds = self.model.embed_tokens(input_ids)
          if self.training and (
              pixel_values is not None or pixel_values_videos is not None
          ):
              # Clone once, before EITHER masked in-place write below, so the
              # embedding lookup's output is never mutated directly. Previously
              # only the image path cloned; video-only batches wrote in place.
              inputs_embeds = inputs_embeds.clone()
          if pixel_values is not None:
              # Give the visual encoder's input and output the same dtype as
              # the text embeddings.
              pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
              image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
              image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
              # Overwrite the embeddings at image placeholder token positions.
              image_mask = input_ids == self.config.image_token_id
              inputs_embeds[image_mask] = image_embeds
          if pixel_values_videos is not None:
              # Same dtype alignment for the video path.
              pixel_values_videos = paddle.cast(
                  pixel_values_videos, inputs_embeds.dtype
              )
              video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
              video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
              # Overwrite the embeddings at video placeholder token positions.
              video_mask = input_ids == self.config.video_token_id
              inputs_embeds[video_mask] = video_embeds
      outputs = self.model(
          input_ids=None,
          position_ids=position_ids,
          attention_mask=attention_mask,
          past_key_values=past_key_values,
          inputs_embeds=inputs_embeds,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )
      hidden_states = outputs[0]
      # Keep only the final sequence position as the sequence-level embedding.
      last_hidden_state = hidden_states[:, -1, :]
      return last_hidden_state
  class PPDocBeeInference(Qwen2VLForConditionalGeneration):
      """Qwen2-VL subclass whose ``generate`` takes a prepared input mapping."""

      # At class-definition time, append "docbee_generate" to the list returned
      # by get_inference_operations() and hand it to set_inference_operations().
      set_inference_operations(get_inference_operations() + ["docbee_generate"])

      @benchmark.timeit_with_options(name="docbee_generate")
      def generate(self, inputs, **kwargs):
          """Run the parent ``generate`` on ``inputs`` with gradients disabled.

          Args:
              inputs: Mapping of model inputs, expanded as keyword arguments
                  into the parent ``generate``.
              **kwargs: Only ``max_new_tokens`` (default 2048), ``temperature``
                  (0.1), ``top_p`` (0.001) and ``top_k`` (1) are read; any
                  other keys are ignored.

          Returns:
              Whatever the parent ``generate`` returns.
          """
          option_defaults = (
              ("max_new_tokens", 2048),
              ("temperature", 0.1),
              ("top_p", 0.001),
              ("top_k", 1),
          )
          sampling_options = {
              option: kwargs.get(option, fallback)
              for option, fallback in option_defaults
          }
          with paddle.no_grad():
              return super().generate(**inputs, **sampling_options)