qwen2_vl.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. from dataclasses import dataclass
  17. from functools import partial
  18. from typing import Any, Dict, List, Optional, Tuple, Union
  19. import paddle
  20. import paddle.distributed.fleet.meta_parallel as mpu
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from paddle import Tensor
  24. from paddle.distributed import fleet
  25. from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
  26. from paddle.distributed.fleet.utils import recompute
  27. from .....utils import logging
  28. from ....utils.benchmark import (
  29. benchmark,
  30. get_inference_operations,
  31. set_inference_operations,
  32. )
  33. from ...common.vlm.activations import ACT2FN
  34. from ...common.vlm.bert_padding import index_first_axis, pad_input, unpad_input
  35. from ...common.vlm.flash_attn_utils import has_flash_attn_func
  36. from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
  37. from ...common.vlm.transformers.model_outputs import (
  38. BaseModelOutputWithPast,
  39. ModelOutput,
  40. )
# Resolve the flash-attention kernels once at import time.
# NOTE(review): presumably these are None when flash-attention is unavailable —
# confirm in common/vlm/flash_attn_utils.has_flash_attn_func.
flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
# True when paddle's current device string contains "npu".
_IS_NPU = "npu" in paddle.get_device()
# Linear-layer aliases: the plain layer and the tensor-parallel (column/row) variants.
Linear = nn.Linear
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
  46. class Qwen2VLVisionConfig(PretrainedConfig):
  47. model_type = "qwen2_vl"
  48. def __init__(
  49. self,
  50. depth=32,
  51. embed_dim=1280,
  52. hidden_size=3584,
  53. hidden_act="quick_gelu",
  54. mlp_ratio=4,
  55. num_heads=16,
  56. in_channels=3,
  57. patch_size=14,
  58. spatial_merge_size=2,
  59. temporal_patch_size=2,
  60. attn_implementation="eager", # new added
  61. **kwargs,
  62. ):
  63. super().__init__(**kwargs)
  64. self.depth = depth
  65. self.embed_dim = embed_dim
  66. self.hidden_size = hidden_size
  67. self.hidden_act = hidden_act
  68. self.mlp_ratio = mlp_ratio
  69. self.num_heads = num_heads
  70. self.in_channels = in_channels
  71. self.patch_size = patch_size
  72. self.spatial_merge_size = spatial_merge_size
  73. self.temporal_patch_size = temporal_patch_size
  74. self.attn_implementation = attn_implementation
  75. @classmethod
  76. def from_pretrained(
  77. cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
  78. ) -> "PretrainedConfig":
  79. config_dict, kwargs = cls.get_config_dict(
  80. pretrained_model_name_or_path, **kwargs
  81. )
  82. if config_dict.get("model_type") == "qwen2_vl":
  83. config_dict = config_dict["vision_config"]
  84. if (
  85. "model_type" in config_dict
  86. and hasattr(cls, "model_type")
  87. and config_dict["model_type"] != cls.model_type
  88. ):
  89. logging.warning(
  90. f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
  91. f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
  92. )
  93. return cls.from_dict(config_dict, **kwargs)
  94. class Qwen2VLConfig(PretrainedConfig):
  95. r"""
  96. This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
  97. Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
  98. with the defaults will yield a similar configuration to that of
  99. Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
  100. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  101. documentation from [`PretrainedConfig`] for more information.
  102. Args:
  103. vocab_size (`int`, *optional*, defaults to 152064):
  104. Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
  105. `inputs_ids` passed when calling [`Qwen2VLModel`]
  106. hidden_size (`int`, *optional*, defaults to 8192):
  107. Dimension of the hidden representations.
  108. intermediate_size (`int`, *optional*, defaults to 29568):
  109. Dimension of the MLP representations.
  110. num_hidden_layers (`int`, *optional*, defaults to 80):
  111. Number of hidden layers in the Transformer encoder.
  112. num_attention_heads (`int`, *optional*, defaults to 64):
  113. Number of attention heads for each attention layer in the Transformer encoder.
  114. num_key_value_heads (`int`, *optional*, defaults to 8):
  115. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  116. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  117. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  118. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  119. by meanpooling all the original heads within that group. For more details checkout [this
  120. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
  121. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  122. The non-linear activation function (function or string) in the decoder.
  123. max_position_embeddings (`int`, *optional*, defaults to 32768):
  124. The maximum sequence length that this model might ever be used with.
  125. initializer_range (`float`, *optional*, defaults to 0.02):
  126. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  127. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  128. The epsilon used by the rms normalization layers.
  129. use_cache (`bool`, *optional*, defaults to `True`):
  130. Whether or not the model should return the last key/values attentions (not used by all models). Only
  131. relevant if `config.is_decoder=True`.
  132. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  133. Whether the model's input and output word embeddings should be tied.
  134. rope_theta (`float`, *optional*, defaults to 1000000.0):
  135. The base period of the RoPE embeddings.
  136. use_sliding_window (`bool`, *optional*, defaults to `False`):
  137. Whether to use sliding window attention.
  138. sliding_window (`int`, *optional*, defaults to 4096):
  139. Sliding window attention (SWA) window size. If not specified, will default to `4096`.
  140. max_window_layers (`int`, *optional*, defaults to 80):
  141. The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  142. attention_dropout (`float`, *optional*, defaults to 0.0):
  143. The dropout ratio for the attention probabilities.
  144. vision_config (`Dict`, *optional*):
  145. The config for the visual encoder initialization.
  146. rope_scaling (`Dict`, *optional*):
  147. Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
  148. strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
  149. `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
  150. `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
  151. these scaling strategies behave:
  152. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
  153. experimental feature, subject to breaking API changes in future versions.
  154. """
  155. model_type = "qwen2_vl"
  156. keys_to_ignore_at_inference = ["past_key_values"]
  157. def __init__(
  158. self,
  159. vocab_size=152064,
  160. hidden_size=8192,
  161. intermediate_size=29568,
  162. num_hidden_layers=80,
  163. num_attention_heads=64,
  164. num_key_value_heads=8,
  165. hidden_act="silu",
  166. max_position_embeddings=32768,
  167. initializer_range=0.02,
  168. rms_norm_eps=1e-05,
  169. use_cache=True,
  170. tie_word_embeddings=False,
  171. rope_theta=1000000.0,
  172. use_sliding_window=False,
  173. sliding_window=4096,
  174. max_window_layers=80,
  175. attention_dropout=0.0,
  176. vision_config=None,
  177. rope_scaling=None,
  178. **kwargs,
  179. ):
  180. if isinstance(vision_config, dict):
  181. self.vision_config = Qwen2VLVisionConfig(**vision_config)
  182. elif vision_config is None:
  183. self.vision_config = Qwen2VLVisionConfig()
  184. self.vocab_size = vocab_size
  185. self.max_position_embeddings = max_position_embeddings
  186. self.hidden_size = hidden_size
  187. self.intermediate_size = intermediate_size
  188. self.num_hidden_layers = num_hidden_layers
  189. self.num_attention_heads = num_attention_heads
  190. self.use_sliding_window = use_sliding_window
  191. self.sliding_window = sliding_window
  192. self.max_window_layers = max_window_layers
  193. if num_key_value_heads is None:
  194. num_key_value_heads = num_attention_heads
  195. self.num_key_value_heads = num_key_value_heads
  196. self.hidden_act = hidden_act
  197. self.initializer_range = initializer_range
  198. self.rms_norm_eps = rms_norm_eps
  199. self.use_cache = use_cache
  200. self.rope_theta = rope_theta
  201. self.attention_dropout = attention_dropout
  202. self.rope_scaling = rope_scaling
  203. super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
  204. def get_triangle_upper_mask(x, mask=None):
  205. if mask is not None:
  206. return mask
  207. shape = x.shape
  208. shape[1] = 1
  209. mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
  210. mask = paddle.triu(mask, diagonal=1)
  211. mask.stop_gradient = True
  212. return mask
  213. def parallel_matmul(
  214. x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
  215. ):
  216. is_fleet_init = True
  217. tensor_parallel_degree = 1
  218. try:
  219. hcg = fleet.get_hybrid_communicate_group()
  220. model_parallel_group = hcg.get_model_parallel_group()
  221. tensor_parallel_degree = hcg.get_model_parallel_world_size()
  222. except:
  223. is_fleet_init = False
  224. if paddle.in_dynamic_mode():
  225. y_is_distributed = y.is_distributed
  226. else:
  227. y_is_distributed = tensor_parallel_degree > 1
  228. if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
  229. input_parallel = paddle.distributed.collective._c_identity(
  230. x, group=model_parallel_group
  231. )
  232. logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
  233. if tensor_parallel_output:
  234. return logits
  235. return paddle.distributed.collective._c_concat(
  236. logits, group=model_parallel_group
  237. )
  238. else:
  239. logits = paddle.matmul(x, y, transpose_y=transpose_y)
  240. return logits
  241. def _compute_default_rope_parameters(
  242. config: Optional[PretrainedConfig] = None,
  243. device: Optional["paddle.device"] = None,
  244. seq_len: Optional[int] = None,
  245. **rope_kwargs,
  246. ) -> Tuple["paddle.Tensor", float]:
  247. """
  248. Computes the inverse frequencies according to the original RoPE implementation
  249. Args:
  250. config ([`~transformers.PretrainedConfig`]):
  251. The model configuration.
  252. device (`paddle.device`):
  253. The device to use for initialization of the inverse frequencies.
  254. seq_len (`int`, *optional*):
  255. The current sequence length. Unused for this type of RoPE.
  256. rope_kwargs (`Dict`, *optional*):
  257. BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
  258. Returns:
  259. Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  260. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  261. """
  262. if config is not None and len(rope_kwargs) > 0:
  263. raise ValueError(
  264. "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
  265. f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
  266. )
  267. if len(rope_kwargs) > 0:
  268. base = rope_kwargs["base"]
  269. dim = rope_kwargs["dim"]
  270. elif config is not None:
  271. base = config.rope_theta
  272. partial_rotary_factor = (
  273. config.partial_rotary_factor
  274. if hasattr(config, "partial_rotary_factor")
  275. else 1.0
  276. )
  277. head_dim = getattr(
  278. config, "head_dim", config.hidden_size // config.num_attention_heads
  279. )
  280. dim = int(head_dim * partial_rotary_factor)
  281. attention_factor = 1.0
  282. inv_freq = 1.0 / (
  283. base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)
  284. )
  285. return inv_freq, attention_factor
  286. ROPE_INIT_FUNCTIONS = {
  287. "default": _compute_default_rope_parameters,
  288. }
  289. def _get_unpad_data(attention_mask):
  290. seqlens_in_batch = attention_mask.sum(axis=-1, dtype="int32")
  291. indices = paddle.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  292. max_seqlen_in_batch = seqlens_in_batch.max().item() # [2, 1, 1323]
  293. cu_seqlens = F.pad(
  294. paddle.cumsum(seqlens_in_batch, axis=0), (1, 0), data_format="NCL"
  295. )
  296. return (
  297. indices,
  298. cu_seqlens,
  299. max_seqlen_in_batch,
  300. )
  301. def is_casual_mask(attention_mask):
  302. """
  303. Upper triangular of attention_mask equals to attention_mask is casual
  304. """
  305. return (paddle.triu(attention_mask) == attention_mask).all().item()
  306. def _make_causal_mask(input_ids_shape, past_key_values_length):
  307. """
  308. Make causal mask used for self-attention
  309. """
  310. batch_size, target_length = input_ids_shape
  311. mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
  312. if past_key_values_length > 0:
  313. mask = paddle.concat(
  314. [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
  315. axis=-1,
  316. )
  317. return mask[None, None, :, :].expand(
  318. [batch_size, 1, target_length, target_length + past_key_values_length]
  319. )
  320. def _expand_2d_mask(mask, dtype, tgt_length):
  321. """
  322. Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
  323. """
  324. batch_size, src_length = mask.shape[0], mask.shape[-1]
  325. tgt_length = tgt_length if tgt_length is not None else src_length
  326. mask = mask[:, None, None, :].astype("bool")
  327. mask.stop_gradient = True
  328. expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
  329. return expanded_mask
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
    """
    Output container for the Qwen2-VL causal (autoregressive) language model.

    Every field defaults to None, so callers populate only what they computed.

    Args:
        loss (`paddle.Tensor` of shape `(1,)`, *optional*):
            Language-modeling (next-token prediction) loss; expected to be set when labels are provided.
        logits (`paddle.Tensor`, presumably `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language-modeling head (before softmax).
        past_key_values (*optional*):
            Cached key/value states from the self-attention blocks, reusable to speed up sequential decoding.
            NOTE(review): the annotation says `List[paddle.Tensor]`; the exact per-layer structure is set by the
            model code outside this class — confirm there.
        hidden_states (`tuple(paddle.Tensor)`, *optional*):
            Hidden states of the model (embedding output plus each layer's output), when requested.
        attentions (`tuple(paddle.Tensor)`, *optional*):
            Per-layer attention weights after softmax, when requested.
        rope_deltas (`paddle.Tensor`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[paddle.Tensor] = None
    logits: paddle.Tensor = None
    past_key_values: Optional[List[paddle.Tensor]] = None
    hidden_states: Optional[Tuple[paddle.Tensor]] = None
    attentions: Optional[Tuple[paddle.Tensor]] = None
    rope_deltas: Optional[paddle.Tensor] = None
  362. class Qwen2VLRotaryEmbedding(nn.Layer):
  363. def __init__(
  364. self,
  365. dim=None,
  366. max_position_embeddings=2048,
  367. base=10000,
  368. device=None,
  369. scaling_factor=1.0,
  370. rope_type="default",
  371. config: Optional[Qwen2VLConfig] = None,
  372. ):
  373. super().__init__()
  374. self.rope_kwargs = {}
  375. if config is None:
  376. self.rope_kwargs = {
  377. "rope_type": rope_type,
  378. "factor": scaling_factor,
  379. "dim": dim,
  380. "base": base,
  381. "max_position_embeddings": max_position_embeddings,
  382. }
  383. self.rope_type = rope_type
  384. self.max_seq_len_cached = max_position_embeddings
  385. self.original_max_seq_len = max_position_embeddings
  386. else:
  387. # BC: "rope_type" was originally "type"
  388. if config.rope_scaling is not None:
  389. self.rope_type = config.rope_scaling.get(
  390. "rope_type", config.rope_scaling.get("type")
  391. )
  392. else:
  393. self.rope_type = "default"
  394. self.max_seq_len_cached = config.max_position_embeddings
  395. self.original_max_seq_len = config.max_position_embeddings
  396. self.config = config
  397. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  398. self.inv_freq, self.attention_scaling = self.rope_init_fn(
  399. self.config, device, **self.rope_kwargs
  400. )
  401. self.original_inv_freq = self.inv_freq
  402. self._set_cos_sin_cache(seq_len=max_position_embeddings)
  403. def _set_cos_sin_cache(self, seq_len):
  404. self.max_seq_len_cached = seq_len
  405. t = paddle.arange(seq_len, dtype="float32")
  406. freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
  407. emb = paddle.concat([freqs, freqs], axis=-1)
  408. self.cos_cached = emb.cos()
  409. self.sin_cached = emb.sin()
  410. def _dynamic_frequency_update(self, position_ids, device):
  411. """
  412. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  413. 1 - growing beyond the cached sequence length (allow scaling)
  414. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  415. """
  416. seq_len = paddle.max(position_ids) + 1
  417. if seq_len > self.max_seq_len_cached: # growth
  418. inv_freq, self.attention_scaling = self.rope_init_fn(
  419. self.config, device, seq_len=seq_len, **self.rope_kwargs
  420. )
  421. self.inv_freq = inv_freq
  422. self.max_seq_len_cached = seq_len
  423. if (
  424. seq_len < self.original_max_seq_len
  425. and self.max_seq_len_cached > self.original_max_seq_len
  426. ): # reset
  427. self.inv_freq = self.original_inv_freq
  428. self.max_seq_len_cached = self.original_max_seq_len
  429. @paddle.no_grad()
  430. def forward(self, x, position_ids):
  431. if "dynamic" in self.rope_type:
  432. self._dynamic_frequency_update(position_ids, device=x.device)
  433. inv_freq_expanded = (
  434. self.inv_freq[None, None, :, None]
  435. .astype("float32")
  436. .expand([3, position_ids.shape[1], -1, 1])
  437. )
  438. position_ids_expanded = position_ids[:, :, None, :].astype("float32")
  439. device_type = paddle.get_device()
  440. device_type = (
  441. device_type
  442. if isinstance(device_type, str) and device_type != "mps"
  443. else "cpu"
  444. )
  445. with paddle.amp.auto_cast():
  446. freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded)
  447. freqs = freqs.transpose([0, 1, 3, 2])
  448. emb = paddle.concat((freqs, freqs), axis=-1)
  449. cos = emb.cos()
  450. sin = emb.sin()
  451. cos = cos * self.attention_scaling
  452. sin = sin * self.attention_scaling
  453. return cos.astype(x.dtype), sin.astype(x.dtype)
  454. # Copied from transformers.models.llama.modeling_llama.rotate_half
  455. def rotate_half(x):
  456. """Rotates half the hidden dims of the input."""
  457. x1 = x[..., : x.shape[-1] // 2]
  458. x2 = x[..., x.shape[-1] // 2 :]
  459. return paddle.concat([-x2, x1], axis=-1)
  460. def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  461. """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
  462. Explanation:
  463. Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
  464. sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
  465. vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
  466. Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
  467. For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
  468. height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
  469. difference with modern LLMs.
  470. Args:
  471. q (`paddle.Tensor`): The query tensor.
  472. k (`paddle.Tensor`): The key tensor.
  473. cos (`paddle.Tensor`): The cosine part of the rotary embedding.
  474. sin (`paddle.Tensor`): The sine part of the rotary embedding.
  475. position_ids (`paddle.Tensor`):
  476. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  477. used to pass offsetted position ids when working with a KV-cache.
  478. mrope_section(`List(int)`):
  479. Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
  480. unsqueeze_dim (`int`, *optional*, defaults to 1):
  481. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  482. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  483. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  484. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  485. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  486. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  487. Returns:
  488. `tuple(paddle.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  489. """
  490. mrope_section = mrope_section * 2
  491. cos = paddle.concat(
  492. x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1
  493. ).unsqueeze(axis=unsqueeze_dim)
  494. sin = paddle.concat(
  495. x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1
  496. ).unsqueeze(axis=unsqueeze_dim)
  497. q_embed = (q * cos) + (rotate_half(q) * sin)
  498. k_embed = (k * cos) + (rotate_half(k) * sin)
  499. return q_embed, k_embed
  500. def apply_rotary_pos_emb_vision(
  501. tensor: paddle.Tensor, freqs: paddle.Tensor
  502. ) -> paddle.Tensor:
  503. orig_dtype = tensor.dtype
  504. with paddle.amp.auto_cast(False):
  505. tensor = tensor.astype(dtype="float32")
  506. cos = freqs.cos()
  507. sin = freqs.sin()
  508. cos = (
  509. cos.unsqueeze(1)
  510. .tile(repeat_times=[1, 1, 2])
  511. .unsqueeze(0)
  512. .astype(dtype="float32")
  513. )
  514. sin = (
  515. sin.unsqueeze(1)
  516. .tile(repeat_times=[1, 1, 2])
  517. .unsqueeze(0)
  518. .astype(dtype="float32")
  519. )
  520. output = tensor * cos + rotate_half(tensor) * sin
  521. output = paddle.cast(output, orig_dtype)
  522. return output
  523. class VisionRotaryEmbedding(nn.Layer):
  524. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  525. super().__init__()
  526. self.inv_freq = 1.0 / theta ** (
  527. paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim
  528. )
  529. def forward(self, seqlen: int) -> paddle.Tensor:
  530. seq = paddle.arange(seqlen).cast(self.inv_freq.dtype)
  531. freqs = paddle.outer(x=seq, y=self.inv_freq)
  532. return freqs
  533. class PatchEmbed(nn.Layer):
  534. def __init__(
  535. self,
  536. patch_size: int = 14,
  537. temporal_patch_size: int = 2,
  538. in_channels: int = 3,
  539. embed_dim: int = 1152,
  540. ) -> None:
  541. super().__init__()
  542. self.patch_size = patch_size
  543. self.temporal_patch_size = temporal_patch_size
  544. self.in_channels = in_channels
  545. self.embed_dim = embed_dim
  546. kernel_size = [temporal_patch_size, patch_size, patch_size]
  547. self.proj = nn.Conv3D(
  548. in_channels,
  549. embed_dim,
  550. kernel_size=kernel_size,
  551. stride=kernel_size,
  552. bias_attr=False,
  553. )
  554. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  555. target_dtype = self.proj.weight.dtype
  556. hidden_states = hidden_states.reshape(
  557. [
  558. -1,
  559. self.in_channels,
  560. self.temporal_patch_size,
  561. self.patch_size,
  562. self.patch_size,
  563. ]
  564. )
  565. # NOTE(changwenbin): AttributeError: 'Variable' object has no attribute 'to'.
  566. # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
  567. # hidden_states = paddle.cast(hidden_states, dtype=target_dtype)
  568. hidden_states = self.proj(
  569. paddle.cast(hidden_states, dtype=target_dtype)
  570. ).reshape([-1, self.embed_dim])
  571. return hidden_states
  572. class PatchMerger(nn.Layer):
  573. def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
  574. super().__init__()
  575. self.hidden_size = context_dim * (spatial_merge_size**2)
  576. self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6)
  577. self.mlp = nn.Sequential(
  578. nn.Linear(self.hidden_size, self.hidden_size),
  579. nn.GELU(),
  580. nn.Linear(self.hidden_size, dim),
  581. )
  582. def forward(self, x: paddle.Tensor) -> paddle.Tensor:
  583. x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
  584. return x
  585. class VisionMlp(nn.Layer):
  586. def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
  587. super().__init__()
  588. self.fc1 = nn.Linear(dim, hidden_dim)
  589. self.act = ACT2FN[hidden_act]
  590. self.fc2 = nn.Linear(hidden_dim, dim)
  591. def forward(self, x) -> paddle.Tensor:
  592. return self.fc2(self.act(self.fc1(x)))
  593. class VisionAttention(nn.Layer):
  594. def __init__(self, dim: int, num_heads: int = 16) -> None:
  595. super().__init__()
  596. self.num_heads = num_heads
  597. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  598. self.proj = nn.Linear(dim, dim)
  599. self.head_dim = dim // num_heads # must added
  600. def forward(
  601. self,
  602. hidden_states: paddle.Tensor,
  603. cu_seqlens: paddle.Tensor,
  604. rotary_pos_emb: paddle.Tensor = None,
  605. ) -> paddle.Tensor:
  606. seq_length = hidden_states.shape[0]
  607. q, k, v = (
  608. self.qkv(hidden_states)
  609. .reshape([seq_length, 3, self.num_heads, -1])
  610. .transpose([1, 0, 2, 3])
  611. .unbind(0)
  612. )
  613. q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
  614. k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
  615. attention_mask = paddle.zeros([1, seq_length, seq_length], dtype="bool")
  616. for i in range(1, len(cu_seqlens)):
  617. attention_mask[
  618. ...,
  619. cu_seqlens[i - 1] : cu_seqlens[i],
  620. cu_seqlens[i - 1] : cu_seqlens[i],
  621. ] = True
  622. zero = paddle.zeros(attention_mask.shape, dtype=hidden_states.dtype)
  623. neg_inf = paddle.full_like(
  624. attention_mask,
  625. paddle.finfo(hidden_states.dtype).min,
  626. dtype=hidden_states.dtype,
  627. )
  628. attention_mask = paddle.where(attention_mask, zero, neg_inf)
  629. q = q.transpose([1, 0, 2])
  630. k = k.transpose([1, 0, 2])
  631. v = v.transpose([1, 0, 2])
  632. attn_weights = paddle.matmul(q, k.transpose([0, 2, 1])) / math.sqrt(
  633. self.head_dim
  634. )
  635. attn_weights = attn_weights + attention_mask
  636. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  637. attn_output = paddle.matmul(attn_weights, v)
  638. attn_output = attn_output.transpose([1, 0, 2])
  639. attn_output = attn_output.reshape([seq_length, -1])
  640. attn_output = self.proj(attn_output)
  641. return attn_output
  642. class VisionFlashAttention2(nn.Layer):
  643. def __init__(self, dim: int, num_heads: int = 16) -> None:
  644. super().__init__()
  645. self.num_heads = num_heads
  646. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  647. self.proj = nn.Linear(dim, dim)
  648. self.head_dim = dim // num_heads # must added
  649. def forward(
  650. self,
  651. hidden_states: paddle.Tensor,
  652. cu_seqlens: paddle.Tensor,
  653. rotary_pos_emb: paddle.Tensor = None,
  654. ) -> paddle.Tensor:
  655. seq_length = tuple(hidden_states.shape)[0]
  656. qkv = (
  657. self.qkv(hidden_states)
  658. .reshape([seq_length, 3, self.num_heads, -1])
  659. .transpose(perm=[1, 0, 2, 3])
  660. )
  661. q, k, v = qkv.unbind(axis=0)
  662. q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  663. axis=0
  664. )
  665. k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  666. axis=0
  667. )
  668. if _IS_NPU:
  669. attn_output = paddle.nn.functional.flash_attention_npu(
  670. q.astype("bfloat16"),
  671. k.astype("bfloat16"),
  672. v.astype("bfloat16"),
  673. is_varlen=True,
  674. batch_size=1,
  675. seq_length=seq_length,
  676. ).reshape([seq_length, -1])
  677. else:
  678. max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
  679. softmax_scale = self.head_dim**-0.5
  680. attn_output = (
  681. flash_attn_varlen_func(
  682. q.astype("bfloat16"),
  683. k.astype("bfloat16"),
  684. v.astype("bfloat16"),
  685. cu_seqlens,
  686. cu_seqlens,
  687. max_seqlen,
  688. max_seqlen,
  689. scale=softmax_scale,
  690. )[0]
  691. .squeeze(0)
  692. .reshape([seq_length, -1])
  693. )
  694. if self.proj.weight.dtype == paddle.bfloat16:
  695. attn_output = attn_output.astype(paddle.bfloat16)
  696. elif self.proj.weight.dtype == paddle.float16:
  697. attn_output = attn_output.astype(paddle.float16)
  698. elif self.proj.weight.dtype == paddle.float32:
  699. attn_output = attn_output.astype(paddle.float32)
  700. attn_output = self.proj(attn_output)
  701. return attn_output
  702. def create_attention_module(config, module_type, layer_idx=None):
  703. if flash_attn_func is not None:
  704. if module_type == "qwen2vl":
  705. return Qwen2VLFlashAttention2(config, layer_idx)
  706. elif module_type == "vision":
  707. return VisionFlashAttention2(config.embed_dim, num_heads=config.num_heads)
  708. else:
  709. logging.warning_once(
  710. f"Warning: Flash Attention2 is not available for {module_type}, fallback to normal attention."
  711. )
  712. if module_type == "qwen2vl":
  713. return Qwen2VLAttention(config, layer_idx)
  714. elif module_type == "vision":
  715. return VisionAttention(config.embed_dim, num_heads=config.num_heads)
  716. class Qwen2VLVisionBlock(nn.Layer):
  717. def __init__(self, config, attn_implementation: str = "flash_attention_2") -> None:
  718. super().__init__()
  719. self.norm1 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  720. self.norm2 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  721. mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
  722. self.attn = create_attention_module(config, "vision")
  723. self.mlp = VisionMlp(
  724. dim=config.embed_dim,
  725. hidden_dim=mlp_hidden_dim,
  726. hidden_act=config.hidden_act,
  727. )
  728. def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
  729. hidden_states = hidden_states + self.attn(
  730. self.norm1(hidden_states),
  731. cu_seqlens=cu_seqlens,
  732. rotary_pos_emb=rotary_pos_emb,
  733. )
  734. hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
  735. return hidden_states
  736. def _prepare_4d_causal_attention_mask_with_cache_position(
  737. attention_mask: paddle.Tensor,
  738. sequence_length: int,
  739. target_length: int,
  740. dtype: paddle.dtype,
  741. min_dtype: float,
  742. cache_position: paddle.Tensor,
  743. batch_size: int,
  744. ):
  745. """
  746. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  747. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  748. Args:
  749. attention_mask (`paddle.Tensor`):
  750. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  751. sequence_length (`int`):
  752. The sequence length being processed.
  753. target_length (`int`):
  754. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  755. dtype (`paddle.dtype`):
  756. The dtype to use for the 4D attention mask.
  757. min_dtype (`float`):
  758. The minimum value representable with the dtype `dtype`.
  759. cache_position (`paddle.Tensor`):
  760. Indices depicting the position of the input sequence tokens in the sequence.
  761. batch_size (`paddle.Tensor`):
  762. Batch size.
  763. """
  764. if attention_mask is not None and attention_mask.dim() == 4:
  765. causal_mask = attention_mask
  766. else:
  767. causal_mask = paddle.full(
  768. [sequence_length, target_length], fill_value=min_dtype, dtype=dtype
  769. )
  770. if sequence_length != 1:
  771. causal_mask = paddle.triu(x=causal_mask, diagonal=1)
  772. causal_mask *= paddle.arange(target_length) > cache_position.reshape([-1, 1])
  773. causal_mask = causal_mask[None, None, :, :].expand(
  774. shape=[batch_size, 1, -1, -1]
  775. )
  776. if attention_mask is not None:
  777. causal_mask = causal_mask.clone()
  778. mask_length = tuple(attention_mask.shape)[-1]
  779. padding_mask = (
  780. causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  781. )
  782. padding_mask = padding_mask == 0
  783. causal_mask[:, :, :, :mask_length] = causal_mask[
  784. :, :, :, :mask_length
  785. ].masked_fill(mask=padding_mask, value=min_dtype)
  786. return causal_mask
  787. class Qwen2RMSNorm(nn.Layer):
  788. def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
  789. """
  790. Qwen2RMSNorm is equivalent to T5LayerNorm
  791. """
  792. super().__init__()
  793. self.weight = paddle.create_parameter(
  794. shape=[hidden_size],
  795. dtype=paddle.get_default_dtype(),
  796. default_initializer=nn.initializer.Constant(1.0),
  797. )
  798. self.variance_epsilon = eps
  799. def forward(self, hidden_states):
  800. if paddle.in_dynamic_mode():
  801. with paddle.amp.auto_cast(False):
  802. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  803. hidden_states = (
  804. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  805. )
  806. else:
  807. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  808. hidden_states = (
  809. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  810. )
  811. if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
  812. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  813. return hidden_states * self.weight
  814. class Qwen2MLP(nn.Layer):
  815. def __init__(self, config):
  816. super().__init__()
  817. self.hidden_size = config.hidden_size
  818. self.intermediate_size = config.intermediate_size
  819. self.fuse_attention_ffn = config.fuse_attention_ffn
  820. self.tensor_parallel_degree = config.tensor_parallel_degree
  821. if config.tensor_parallel_degree > 1:
  822. self.gate_proj = ColumnParallelLinear(
  823. self.hidden_size,
  824. self.intermediate_size,
  825. gather_output=False,
  826. has_bias=False,
  827. )
  828. self.up_proj = ColumnParallelLinear(
  829. self.hidden_size,
  830. self.intermediate_size,
  831. gather_output=False,
  832. has_bias=False,
  833. )
  834. self.down_proj = RowParallelLinear(
  835. self.intermediate_size,
  836. self.hidden_size,
  837. input_is_parallel=True,
  838. has_bias=False,
  839. )
  840. else:
  841. self.gate_proj = Linear(
  842. self.hidden_size, self.intermediate_size, bias_attr=False
  843. ) # w1
  844. self.up_proj = Linear(
  845. self.hidden_size, self.intermediate_size, bias_attr=False
  846. ) # w3
  847. self.down_proj = Linear(
  848. self.intermediate_size, self.hidden_size, bias_attr=False
  849. ) # w2
  850. self.act_fn = ACT2FN[config.hidden_act]
  851. self.fuse_swiglu = False
  852. def forward(self, x):
  853. x, y = self.gate_proj(x), self.up_proj(x)
  854. if self.fuse_swiglu:
  855. x = self.act_fn(x, y)
  856. else:
  857. x = self.act_fn(x) * y
  858. return self.down_proj(x)
  859. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  860. def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
  861. """
  862. This is the equivalent of paddle.repeat_interleave(x, axis=1, repeats=n_rep). The hidden states go from (batch,
  863. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  864. """
  865. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  866. if n_rep == 1:
  867. return hidden_states
  868. hidden_states = hidden_states[:, :, None, :, :].expand(
  869. [batch, num_key_value_heads, n_rep, slen, head_dim]
  870. )
  871. return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim])
  872. class Qwen2VLAttention(nn.Layer):
  873. """
  874. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  875. and "Generating Long Sequences with Sparse Transformers".
  876. """
  877. def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
  878. super().__init__()
  879. self.config = config
  880. self.layer_idx = layer_idx
  881. if layer_idx is None:
  882. logging.warning_once(
  883. f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
  884. "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  885. "when creating this class."
  886. )
  887. self.hidden_size = config.hidden_size
  888. self.num_heads = config.num_attention_heads
  889. self.head_dim = self.hidden_size // self.num_heads
  890. self.num_key_value_heads = config.num_key_value_heads
  891. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  892. self.max_position_embeddings = config.max_position_embeddings
  893. self.rope_theta = config.rope_theta
  894. self.is_causal = True
  895. self.attention_dropout = config.attention_dropout
  896. self.rope_scaling = config.rope_scaling
  897. # self.sequence_parallel = config.sequence_parallel
  898. if config.tensor_parallel_degree > 1:
  899. assert (
  900. self.num_heads % config.tensor_parallel_degree == 0
  901. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  902. self.num_heads = self.num_heads // config.tensor_parallel_degree
  903. assert (
  904. self.num_key_value_heads % config.tensor_parallel_degree == 0
  905. ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  906. self.num_key_value_heads = (
  907. self.num_key_value_heads // config.tensor_parallel_degree
  908. )
  909. if config.tensor_parallel_degree > 1:
  910. self.q_proj = ColumnParallelLinear(
  911. self.hidden_size, self.hidden_size, has_bias=True, gather_output=False
  912. )
  913. self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  914. self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  915. self.o_proj = RowParallelLinear(
  916. self.hidden_size,
  917. self.hidden_size,
  918. has_bias=False,
  919. input_is_parallel=True,
  920. )
  921. else:
  922. self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
  923. self.k_proj = Linear(
  924. self.hidden_size,
  925. self.config.num_key_value_heads * self.head_dim,
  926. bias_attr=True,
  927. )
  928. self.v_proj = Linear(
  929. self.hidden_size,
  930. self.config.num_key_value_heads * self.head_dim,
  931. bias_attr=True,
  932. )
  933. self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
  934. self.rotary_emb = Qwen2VLRotaryEmbedding(
  935. self.head_dim,
  936. max_position_embeddings=self.max_position_embeddings,
  937. base=self.rope_theta,
  938. )
  939. def forward(
  940. self,
  941. hidden_states: paddle.Tensor,
  942. attention_mask: Optional[paddle.Tensor] = None,
  943. position_ids: Optional[paddle.Tensor] = None,
  944. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  945. output_attentions: bool = False,
  946. use_cache: bool = False, # default true
  947. cache_position: Optional[paddle.Tensor] = None,
  948. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  949. bsz, q_len, _ = hidden_states.shape
  950. try:
  951. query_states = self.q_proj(hidden_states)
  952. key_states = self.k_proj(hidden_states)
  953. value_states = self.v_proj(hidden_states)
  954. except:
  955. hidden_states = hidden_states.astype(self.config.dtype)
  956. query_states = self.q_proj(hidden_states)
  957. key_states = self.k_proj(hidden_states)
  958. value_states = self.v_proj(hidden_states)
  959. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  960. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  961. query_states = query_states.reshape(shape=target_query_shape)
  962. key_states = key_states.reshape(shape=target_key_value_shape)
  963. value_states = value_states.reshape(shape=target_key_value_shape)
  964. new_perm = [0, 2, 1, 3]
  965. query_states = query_states.transpose(new_perm)
  966. key_states = key_states.transpose(new_perm)
  967. value_states = value_states.transpose(new_perm)
  968. kv_seq_len = key_states.shape[-2]
  969. if past_key_value is not None:
  970. kv_seq_len += cache_position[0] + 1
  971. cos, sin = self.rotary_emb(value_states, position_ids)
  972. query_states, key_states = apply_multimodal_rotary_pos_emb(
  973. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  974. )
  975. if past_key_value is not None:
  976. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  977. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  978. past_key_value = (key_states, value_states) if use_cache else None
  979. # repeat k/v heads if n_kv_heads < n_heads
  980. key_states = repeat_kv(key_states, self.num_key_value_groups)
  981. value_states = repeat_kv(value_states, self.num_key_value_groups)
  982. query_states = query_states.astype("float32")
  983. key_states = key_states.astype("float32")
  984. value_states = value_states.astype("float32")
  985. attn_weights = paddle.matmul(
  986. query_states, key_states.transpose([0, 1, 3, 2])
  987. ) / math.sqrt(self.head_dim)
  988. if attention_mask is not None:
  989. attn_weights = attn_weights + attention_mask
  990. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  991. attn_weights = nn.functional.dropout(
  992. x=attn_weights, p=self.attention_dropout, training=self.training
  993. )
  994. attn_output = paddle.matmul(
  995. attn_weights.cast(self.config.dtype), value_states.cast(self.config.dtype)
  996. )
  997. if attn_output.shape != [bsz, self.num_heads, q_len, self.head_dim]:
  998. raise ValueError(
  999. f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
  1000. f" {attn_output.shape}"
  1001. )
  1002. attn_output = attn_output.transpose([0, 2, 1, 3])
  1003. attn_output = attn_output.reshape([bsz, q_len, -1])
  1004. if self.o_proj.weight.dtype == paddle.bfloat16:
  1005. attn_output = attn_output.astype(paddle.bfloat16)
  1006. elif self.o_proj.weight.dtype == paddle.float16:
  1007. attn_output = attn_output.astype(paddle.float16)
  1008. elif self.o_proj.weight.dtype == paddle.float32:
  1009. attn_output = attn_output.astype(paddle.float32)
  1010. attn_output = self.o_proj(attn_output)
  1011. if not output_attentions:
  1012. attn_weights = None
  1013. return attn_output, attn_weights, past_key_value
  1014. class Qwen2VLFlashAttention2(Qwen2VLAttention):
  1015. """
  1016. Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
  1017. as the weights of the module stays untouched. The only required change would be on the forward pass
  1018. where it needs to correctly call the public API of flash attention and deal with padding tokens
  1019. in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
  1020. config.max_window_layers layers.
  1021. """
  1022. def __init__(self, *args, **kwargs):
  1023. super().__init__(*args, **kwargs)
  1024. def forward(
  1025. self,
  1026. hidden_states: paddle.Tensor,
  1027. attention_mask: Optional[paddle.Tensor] = None,
  1028. position_ids: Optional[paddle.Tensor] = None,
  1029. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1030. output_attentions: bool = False,
  1031. use_cache: bool = False, # default true
  1032. cache_position: Optional[paddle.Tensor] = None,
  1033. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  1034. bsz, q_len, _ = tuple(hidden_states.shape)
  1035. try:
  1036. query_states = self.q_proj(hidden_states)
  1037. key_states = self.k_proj(hidden_states)
  1038. value_states = self.v_proj(hidden_states)
  1039. except:
  1040. hidden_states = hidden_states.astype("bfloat16")
  1041. query_states = self.q_proj(hidden_states)
  1042. key_states = self.k_proj(hidden_states)
  1043. value_states = self.v_proj(hidden_states)
  1044. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  1045. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  1046. query_states = query_states.reshape(shape=target_query_shape)
  1047. key_states = key_states.reshape(shape=target_key_value_shape)
  1048. value_states = value_states.reshape(shape=target_key_value_shape)
  1049. new_perm = [0, 2, 1, 3]
  1050. query_states = query_states.transpose(new_perm)
  1051. key_states = key_states.transpose(new_perm)
  1052. value_states = value_states.transpose(new_perm)
  1053. kv_seq_len = key_states.shape[-2]
  1054. if past_key_value is not None:
  1055. kv_seq_len += cache_position[0] + 1
  1056. # Because the input can be padded, the absolute sequence length depends on the max position id.
  1057. cos, sin = self.rotary_emb(value_states, position_ids)
  1058. query_states, key_states = apply_multimodal_rotary_pos_emb(
  1059. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  1060. )
  1061. if past_key_value is not None:
  1062. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  1063. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  1064. past_key_value = (key_states, value_states) if use_cache else None
  1065. # repeat k/v heads if n_kv_heads < n_heads
  1066. key_states = repeat_kv(key_states, self.num_key_value_groups)
  1067. value_states = repeat_kv(value_states, self.num_key_value_groups)
  1068. # Reashape to the expected shape for Flash Attention
  1069. # [1, 3599, 12, 128]
  1070. query_states = query_states.transpose(perm=[0, 2, 1, 3])
  1071. key_states = key_states.transpose(perm=[0, 2, 1, 3])
  1072. value_states = value_states.transpose(perm=[0, 2, 1, 3])
  1073. attn_output = self._flash_attention_forward(
  1074. query_states, key_states, value_states, attention_mask, q_len
  1075. )
  1076. attn_output = attn_output.reshape([bsz, q_len, -1])
  1077. attn_output = self.o_proj(attn_output)
  1078. if not output_attentions:
  1079. attn_weights = None
  1080. return attn_output, attn_weights, past_key_value
  1081. def _flash_attention_forward(
  1082. self,
  1083. query_states,
  1084. key_states,
  1085. value_states,
  1086. attention_mask,
  1087. query_length,
  1088. dropout=0.0,
  1089. softmax_scale=None,
  1090. ):
  1091. """
  1092. Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
  1093. first unpad the input, then computes the attention scores and pad the final attention scores.
  1094. Args:
  1095. query_states (`paddle.Tensor`):
  1096. Input query states to be passed to Flash Attention API
  1097. key_states (`paddle.Tensor`):
  1098. Input key states to be passed to Flash Attention API
  1099. value_states (`paddle.Tensor`):
  1100. Input value states to be passed to Flash Attention API
  1101. attention_mask (`paddle.Tensor`):
  1102. The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
  1103. position of padding tokens and 1 for the position of non-padding tokens.
  1104. dropout (`int`, *optional*):
  1105. Attention dropout
  1106. softmax_scale (`float`, *optional*):
  1107. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
  1108. """
  1109. # Contains at least one padding token in the sequence
  1110. causal = self.is_causal and query_length != 1
  1111. if _IS_NPU:
  1112. if attention_mask is not None:
  1113. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1114. query_states,
  1115. key_states,
  1116. value_states,
  1117. attn_mask=attention_mask,
  1118. dropout=dropout,
  1119. causal=causal,
  1120. is_varlen=True,
  1121. )
  1122. else:
  1123. dtype = query_states.dtype
  1124. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1125. query_states.astype("bfloat16"),
  1126. key_states.astype("bfloat16"),
  1127. value_states.astype("bfloat16"),
  1128. attn_mask=attention_mask,
  1129. dropout=dropout,
  1130. causal=causal,
  1131. )
  1132. attn_output = attn_output.astype(dtype)
  1133. else:
  1134. head_dim = query_states.shape[-1]
  1135. softmax_scale = head_dim**-0.5 # TODO: 需要手动加上
  1136. if attention_mask is not None:
  1137. batch_size = query_states.shape[0]
  1138. (
  1139. query_states,
  1140. key_states,
  1141. value_states,
  1142. indices_q,
  1143. cu_seq_lens,
  1144. max_seq_lens,
  1145. ) = self._unpad_input(
  1146. query_states, key_states, value_states, attention_mask, query_length
  1147. )
  1148. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  1149. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  1150. attn_output_unpad = flash_attn_varlen_func(
  1151. query_states,
  1152. key_states,
  1153. value_states,
  1154. cu_seqlens_q=cu_seqlens_q,
  1155. cu_seqlens_k=cu_seqlens_k,
  1156. max_seqlen_q=max_seqlen_in_batch_q,
  1157. max_seqlen_k=max_seqlen_in_batch_k,
  1158. scale=softmax_scale, # not softmax_scale=
  1159. dropout=dropout,
  1160. causal=causal,
  1161. )[0]
  1162. attn_output = pad_input(
  1163. attn_output_unpad, indices_q, batch_size, query_length
  1164. )
  1165. else:
  1166. attn_output = flash_attn_func(
  1167. query_states,
  1168. key_states,
  1169. value_states,
  1170. dropout,
  1171. causal=causal,
  1172. )[0]
  1173. return attn_output
  1174. def _unpad_input(
  1175. self, query_layer, key_layer, value_layer, attention_mask, query_length
  1176. ):
  1177. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
  1178. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  1179. # TODO:cuda error
  1180. key_layer = index_first_axis(
  1181. key_layer.reshape([batch_size * kv_seq_len, num_key_value_heads, head_dim]),
  1182. indices_k,
  1183. )
  1184. value_layer = index_first_axis(
  1185. value_layer.reshape(
  1186. [batch_size * kv_seq_len, num_key_value_heads, head_dim]
  1187. ),
  1188. indices_k,
  1189. )
  1190. if query_length == kv_seq_len:
  1191. query_layer = index_first_axis(
  1192. query_layer.reshape(
  1193. [batch_size * kv_seq_len, self.num_heads, head_dim]
  1194. ),
  1195. indices_k,
  1196. )
  1197. cu_seqlens_q = cu_seqlens_k
  1198. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  1199. indices_q = indices_k
  1200. elif query_length == 1:
  1201. max_seqlen_in_batch_q = 1
  1202. cu_seqlens_q = paddle.arange(
  1203. batch_size + 1, dtype=paddle.int32
  1204. ) # There is a memcpy here, that is very bad.
  1205. indices_q = cu_seqlens_q[:-1]
  1206. query_layer = query_layer.squeeze(1)
  1207. else:
  1208. # The -q_len: slice assumes left padding.
  1209. attention_mask = attention_mask[:, -query_length:]
  1210. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
  1211. query_layer, attention_mask
  1212. )
  1213. return (
  1214. query_layer,
  1215. key_layer,
  1216. value_layer,
  1217. indices_q.to(paddle.int64),
  1218. (cu_seqlens_q, cu_seqlens_k),
  1219. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  1220. )
  1221. class Qwen2VLDecoderLayer(nn.Layer):
  1222. def __init__(self, config: Qwen2VLConfig, layer_idx: int):
  1223. super().__init__()
  1224. self.hidden_size = config.hidden_size
  1225. # use_sliding_window false
  1226. if (
  1227. config.use_sliding_window
  1228. and config.attn_implementation != "flash_attention_2"
  1229. ):
  1230. logging.warning_once(
  1231. f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
  1232. "unexpected results may be encountered."
  1233. )
  1234. self.self_attn = create_attention_module(config, "qwen2vl", layer_idx=layer_idx)
  1235. # self.self_attn = Qwen2VLAttention(config, layer_idx)
  1236. self.mlp = Qwen2MLP(config)
  1237. self.input_layernorm = Qwen2RMSNorm(
  1238. config, config.hidden_size, eps=config.rms_norm_eps
  1239. )
  1240. self.post_attention_layernorm = Qwen2RMSNorm(
  1241. config, config.hidden_size, eps=config.rms_norm_eps
  1242. )
  1243. def forward(
  1244. self,
  1245. hidden_states: paddle.Tensor,
  1246. attention_mask: Optional[paddle.Tensor] = None,
  1247. position_ids: Optional[paddle.Tensor] = None,
  1248. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1249. output_attentions: Optional[bool] = False,
  1250. use_cache: Optional[bool] = False,
  1251. cache_position: Optional[paddle.Tensor] = None,
  1252. **kwargs,
  1253. ):
  1254. """
  1255. Args:
  1256. hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  1257. attention_mask (`paddle.Tensor`, *optional*): attention mask of size
  1258. `(batch, sequence_length)` where padding elements are indicated by 0.
  1259. output_attentions (`bool`, *optional*):
  1260. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1261. returned tensors for more detail.
  1262. use_cache (`bool`, *optional*):
  1263. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  1264. (see `past_key_values`).
  1265. past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
  1266. cache_position (`paddle.Tensor` of shape `(sequence_length)`, *optional*):
  1267. Indices depicting the position of the input sequence tokens in the sequence.
  1268. kwargs (`dict`, *optional*):
  1269. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  1270. into the model
  1271. """
  1272. residual = hidden_states
  1273. hidden_states = self.input_layernorm(hidden_states)
  1274. # Self Attention
  1275. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  1276. hidden_states=hidden_states,
  1277. attention_mask=attention_mask,
  1278. position_ids=position_ids,
  1279. past_key_value=past_key_value,
  1280. output_attentions=output_attentions,
  1281. use_cache=use_cache,
  1282. cache_position=cache_position,
  1283. )
  1284. hidden_states = residual + hidden_states
  1285. # Fully Connected
  1286. residual = hidden_states
  1287. hidden_states = self.post_attention_layernorm(hidden_states)
  1288. hidden_states = self.mlp(hidden_states)
  1289. hidden_states = residual + hidden_states
  1290. outputs = (hidden_states,)
  1291. if output_attentions:
  1292. outputs += (self_attn_weights,)
  1293. if use_cache:
  1294. outputs += (present_key_value,)
  1295. return outputs
  1296. class Qwen2VLPreTrainedModel(PretrainedModel):
  1297. config_class = Qwen2VLConfig
  1298. base_model_prefix = "model"
  1299. _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
  1300. _skip_keys_device_placement = "past_key_values"
  1301. def _init_weights(self, layer):
  1302. std = 0.2
  1303. if isinstance(layer, (nn.Linear, nn.Conv3D)):
  1304. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1305. if layer.bias is not None:
  1306. nn.initializer.Constant(0.0)(layer.bias)
  1307. elif isinstance(layer, nn.Embedding):
  1308. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1309. if layer._padding_idx is not None:
  1310. with paddle.no_grad():
  1311. layer.weight[layer._padding_idx] = 0.0
  1312. class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
  1313. config_class = Qwen2VLVisionConfig
  1314. _no_split_modules = ["Qwen2VLVisionBlock"]
  1315. def __init__(self, config) -> None:
  1316. super().__init__(config)
  1317. self.spatial_merge_size = config.spatial_merge_size
  1318. self.patch_embed = PatchEmbed(
  1319. patch_size=config.patch_size,
  1320. temporal_patch_size=config.temporal_patch_size,
  1321. in_channels=config.in_channels,
  1322. embed_dim=config.embed_dim,
  1323. )
  1324. head_dim = config.embed_dim // config.num_heads
  1325. self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
  1326. self.blocks = nn.LayerList(
  1327. [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
  1328. )
  1329. self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
  1330. self.enable_recompute = False
  1331. def get_dtype(self) -> paddle.dtype:
  1332. return self.blocks[0].mlp.fc2.weight.dtype
  1333. def rot_pos_emb(self, grid_thw):
  1334. pos_ids = []
  1335. for t, h, w in grid_thw:
  1336. hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w])
  1337. hpos_ids = hpos_ids.reshape(
  1338. [
  1339. h // self.spatial_merge_size,
  1340. self.spatial_merge_size,
  1341. w // self.spatial_merge_size,
  1342. self.spatial_merge_size,
  1343. ]
  1344. )
  1345. hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3])
  1346. hpos_ids = hpos_ids.flatten()
  1347. wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1])
  1348. wpos_ids = wpos_ids.reshape(
  1349. [
  1350. h // self.spatial_merge_size,
  1351. self.spatial_merge_size,
  1352. w // self.spatial_merge_size,
  1353. self.spatial_merge_size,
  1354. ]
  1355. )
  1356. wpos_ids = wpos_ids.transpose([0, 2, 1, 3])
  1357. wpos_ids = wpos_ids.flatten()
  1358. pos_ids.append(
  1359. paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1])
  1360. )
  1361. pos_ids = paddle.concat(x=pos_ids, axis=0)
  1362. max_grid_size = grid_thw[:, 1:].max()
  1363. rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
  1364. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
  1365. return rotary_pos_emb
  1366. @paddle.jit.not_to_static
  1367. def recompute_training_full(
  1368. self,
  1369. layer_module: nn.Layer,
  1370. hidden_states: paddle.Tensor,
  1371. cu_seqlens_now: paddle.Tensor,
  1372. rotary_pos_emb: paddle.Tensor,
  1373. ):
  1374. def create_custom_forward(module):
  1375. def custom_forward(*inputs):
  1376. return module(*inputs)
  1377. return custom_forward
  1378. hidden_states = recompute(
  1379. create_custom_forward(layer_module),
  1380. hidden_states,
  1381. cu_seqlens_now,
  1382. rotary_pos_emb,
  1383. # use_reentrant=self.config.recompute_use_reentrant,
  1384. )
  1385. return hidden_states
  1386. def forward(
  1387. self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor
  1388. ) -> paddle.Tensor:
  1389. # breakpoint()
  1390. hidden_states = self.patch_embed(hidden_states)
  1391. rotary_pos_emb = self.rot_pos_emb(grid_thw)
  1392. cu_seqlens = paddle.repeat_interleave(
  1393. grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
  1394. ).cumsum(axis=0, dtype="int32")
  1395. cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
  1396. for idx, blk in enumerate(self.blocks):
  1397. if self.enable_recompute and self.training:
  1398. hidden_states = self.recompute_training_full(
  1399. blk, hidden_states, cu_seqlens, rotary_pos_emb
  1400. )
  1401. else:
  1402. hidden_states = blk(
  1403. hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
  1404. )
  1405. return self.merger(hidden_states)
  1406. class Qwen2VLModel(Qwen2VLPreTrainedModel):
  1407. def __init__(self, config: Qwen2VLConfig):
  1408. super().__init__(config)
  1409. self.padding_idx = config.pad_token_id
  1410. self.vocab_size = config.vocab_size
  1411. self.hidden_size = config.hidden_size
  1412. # Recompute defaults to False and is controlled by Trainer
  1413. if (
  1414. config.tensor_parallel_degree > 1
  1415. and config.vocab_size % config.tensor_parallel_degree == 0
  1416. ):
  1417. self.embed_tokens = mpu.VocabParallelEmbedding(
  1418. self.vocab_size,
  1419. self.hidden_size,
  1420. weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
  1421. )
  1422. else:
  1423. self.embed_tokens = nn.Embedding(
  1424. self.vocab_size,
  1425. self.hidden_size,
  1426. )
  1427. # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  1428. self.layers = nn.LayerList(
  1429. [
  1430. Qwen2VLDecoderLayer(config, layer_idx)
  1431. for layer_idx in range(config.num_hidden_layers)
  1432. ]
  1433. )
  1434. self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)
  1435. self.enamble_recompute = False
  1436. def get_input_embeddings(self):
  1437. return self.embed_tokens
  1438. def set_input_embeddings(self, value):
  1439. self.embed_tokens = value
  1440. @staticmethod
  1441. def _prepare_decoder_attention_mask(
  1442. attention_mask, input_shape, past_key_values_length, dtype
  1443. ):
  1444. if attention_mask is not None:
  1445. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1446. if len(attention_mask.shape) == 2:
  1447. expanded_attn_mask = _expand_2d_mask(
  1448. attention_mask, dtype, tgt_length=input_shape[-1]
  1449. )
  1450. # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
  1451. if input_shape[-1] > 1:
  1452. combined_attention_mask = _make_causal_mask(
  1453. input_shape,
  1454. past_key_values_length=past_key_values_length,
  1455. )
  1456. expanded_attn_mask = expanded_attn_mask & combined_attention_mask
  1457. # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
  1458. elif len(attention_mask.shape) == 3:
  1459. expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
  1460. # if attention_mask is already 4-D, do nothing
  1461. else:
  1462. expanded_attn_mask = attention_mask
  1463. else:
  1464. expanded_attn_mask = _make_causal_mask(
  1465. input_shape,
  1466. past_key_values_length=past_key_values_length,
  1467. )
  1468. # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
  1469. expanded_attn_mask = paddle.where(
  1470. expanded_attn_mask, 0.0, paddle.finfo(dtype).min
  1471. ).astype(dtype)
  1472. return expanded_attn_mask
  1473. @paddle.jit.not_to_static
  1474. def recompute_training_full(
  1475. self,
  1476. layer_module: nn.Layer,
  1477. hidden_states: paddle.Tensor,
  1478. attention_mask: paddle.Tensor,
  1479. position_ids: Optional[paddle.Tensor],
  1480. past_key_value: paddle.Tensor,
  1481. output_attentions: bool,
  1482. use_cache: bool,
  1483. cache_position: Optional[paddle.Tensor] = None,
  1484. ):
  1485. def create_custom_forward(module):
  1486. def custom_forward(*inputs):
  1487. return module(*inputs)
  1488. return custom_forward
  1489. hidden_states = recompute(
  1490. create_custom_forward(layer_module),
  1491. hidden_states,
  1492. attention_mask,
  1493. position_ids,
  1494. past_key_value,
  1495. output_attentions,
  1496. use_cache,
  1497. cache_position,
  1498. use_reentrant=self.config.recompute_use_reentrant,
  1499. )
  1500. return hidden_states
  1501. def forward(
  1502. self,
  1503. input_ids: paddle.Tensor = None,
  1504. attention_mask: Optional[paddle.Tensor] = None,
  1505. position_ids: Optional[paddle.Tensor] = None,
  1506. past_key_values: Optional[List[paddle.Tensor]] = None,
  1507. inputs_embeds: Optional[paddle.Tensor] = None,
  1508. use_cache: Optional[bool] = None,
  1509. output_attentions: Optional[bool] = None,
  1510. output_hidden_states: Optional[bool] = None,
  1511. return_dict: Optional[bool] = None,
  1512. cache_position: Optional[paddle.Tensor] = None,
  1513. ) -> Union[Tuple, BaseModelOutputWithPast]:
  1514. output_attentions = (
  1515. output_attentions
  1516. if output_attentions is not None
  1517. else self.config.output_attentions
  1518. )
  1519. output_hidden_states = (
  1520. output_hidden_states
  1521. if output_hidden_states is not None
  1522. else self.config.output_hidden_states
  1523. )
  1524. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1525. return_dict = (
  1526. return_dict if return_dict is not None else self.config.use_return_dict
  1527. )
  1528. if (input_ids is None) ^ (inputs_embeds is not None):
  1529. raise ValueError(
  1530. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  1531. )
  1532. elif input_ids is not None:
  1533. batch_size, seq_length = input_ids.shape
  1534. elif inputs_embeds is not None:
  1535. batch_size, seq_length, _ = inputs_embeds.shape
  1536. else:
  1537. raise ValueError(
  1538. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1539. )
  1540. if past_key_values is None:
  1541. past_key_values = tuple([None] * len(self.layers))
  1542. # NOTE: to make cache can be clear in-time
  1543. past_key_values = list(past_key_values)
  1544. seq_length_with_past = seq_length
  1545. cache_length = 0
  1546. if past_key_values[0] is not None:
  1547. cache_length = past_key_values[0][0].shape[2] # shape[1] in qwen2
  1548. seq_length_with_past += cache_length
  1549. if inputs_embeds is None:
  1550. inputs_embeds = self.embed_tokens(input_ids)
  1551. # embed positions
  1552. if attention_mask is None:
  1553. # [bs, seq_len]
  1554. attention_mask = paddle.ones(
  1555. (batch_size, seq_length_with_past), dtype=paddle.bool
  1556. )
  1557. if flash_attn_varlen_func:
  1558. causal_mask = attention_mask
  1559. else:
  1560. causal_mask = self._prepare_decoder_attention_mask(
  1561. attention_mask,
  1562. (batch_size, seq_length),
  1563. cache_length,
  1564. inputs_embeds.dtype,
  1565. ) # [bs, 1, seq_len, seq_len]
  1566. if cache_position is None:
  1567. past_seen_tokens = (
  1568. past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
  1569. )
  1570. cache_position = paddle.arange(
  1571. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
  1572. )
  1573. if position_ids is None:
  1574. # the hard coded `3` is for temporal, height and width.
  1575. position_ids = cache_position.reshape([1, 1, -1]).expand(
  1576. [3, inputs_embeds.shape[0], -1]
  1577. )
  1578. hidden_states = inputs_embeds
  1579. # decoder layers
  1580. all_hidden_states = () if output_hidden_states else None
  1581. all_self_attns = () if output_attentions else None
  1582. next_decoder_cache = ()
  1583. for idx, (decoder_layer) in enumerate(self.layers):
  1584. if output_hidden_states:
  1585. all_hidden_states += (hidden_states,)
  1586. past_key_value = (
  1587. past_key_values[idx] if past_key_values is not None else None
  1588. )
  1589. if self.enamble_recompute and self.training:
  1590. layer_outputs = self.recompute_training_full(
  1591. decoder_layer,
  1592. hidden_states,
  1593. causal_mask,
  1594. position_ids,
  1595. past_key_value,
  1596. output_attentions,
  1597. use_cache,
  1598. cache_position,
  1599. )
  1600. else:
  1601. layer_outputs = decoder_layer(
  1602. hidden_states,
  1603. attention_mask=causal_mask,
  1604. position_ids=position_ids,
  1605. past_key_value=past_key_value,
  1606. output_attentions=output_attentions, # False
  1607. use_cache=use_cache, # True
  1608. cache_position=cache_position,
  1609. )
  1610. # NOTE: clear outdate cache after it has been used for memory saving
  1611. past_key_value = past_key_values[idx] = None
  1612. hidden_states = layer_outputs[0]
  1613. next_decoder_cache = (
  1614. next_decoder_cache + (layer_outputs[-1],) if use_cache else None
  1615. )
  1616. if output_attentions:
  1617. all_self_attns += (layer_outputs[1],)
  1618. hidden_states = self.norm(hidden_states)
  1619. # add hidden states from the last decoder layer
  1620. if output_hidden_states:
  1621. all_hidden_states += (hidden_states,)
  1622. next_cache = next_decoder_cache if use_cache else None
  1623. if not return_dict:
  1624. return tuple(
  1625. v
  1626. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
  1627. if v is not None
  1628. )
  1629. return BaseModelOutputWithPast(
  1630. last_hidden_state=hidden_states,
  1631. past_key_values=next_cache,
  1632. hidden_states=all_hidden_states,
  1633. attentions=all_self_attns,
  1634. )
  1635. class Qwen2LMHead(nn.Layer):
  1636. def __init__(self, config, embedding_weights=None, transpose_y=False):
  1637. super(Qwen2LMHead, self).__init__()
  1638. self.config = config
  1639. if (
  1640. config.tensor_parallel_degree > 1
  1641. and config.vocab_size % config.tensor_parallel_degree == 0
  1642. ):
  1643. vocab_size = config.vocab_size // config.tensor_parallel_degree
  1644. else:
  1645. vocab_size = config.vocab_size
  1646. self.transpose_y = transpose_y
  1647. if transpose_y:
  1648. # only for weight from embedding_weights
  1649. if embedding_weights is not None:
  1650. self.weight = embedding_weights
  1651. else:
  1652. self.weight = self.create_parameter(
  1653. shape=[vocab_size, config.hidden_size],
  1654. dtype=paddle.get_default_dtype(),
  1655. )
  1656. else:
  1657. if vocab_size != config.vocab_size:
  1658. with get_rng_state_tracker().rng_state():
  1659. self.weight = self.create_parameter(
  1660. shape=[config.hidden_size, vocab_size],
  1661. dtype=paddle.get_default_dtype(),
  1662. )
  1663. else:
  1664. self.weight = self.create_parameter(
  1665. shape=[config.hidden_size, vocab_size],
  1666. dtype=paddle.get_default_dtype(),
  1667. )
  1668. # Must set distributed attr for Tensor Parallel !
  1669. self.weight.is_distributed = (
  1670. True if (vocab_size != config.vocab_size) else False
  1671. )
  1672. if self.weight.is_distributed:
  1673. # for tie_word_embeddings
  1674. self.weight.split_axis = 0 if self.transpose_y else 1
  1675. def forward(self, hidden_states, tensor_parallel_output=None):
  1676. if tensor_parallel_output is None:
  1677. tensor_parallel_output = self.config.tensor_parallel_output
  1678. # 确保数据类型一致
  1679. if self.weight.dtype != hidden_states.dtype:
  1680. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  1681. logits = parallel_matmul(
  1682. hidden_states,
  1683. self.weight,
  1684. transpose_y=self.transpose_y,
  1685. tensor_parallel_output=tensor_parallel_output,
  1686. )
  1687. return logits
  1688. class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
  1689. _tied_weights_keys = ["lm_head.weight"]
  1690. def __init__(self, config, attn_implementation="flash_attention_2"):
  1691. super().__init__(config)
  1692. config._attn_implementation = attn_implementation
  1693. config.vision_config._attn_implementation = attn_implementation
  1694. self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
  1695. config.vision_config
  1696. )
  1697. self.model = Qwen2VLModel(config)
  1698. self.vocab_size = config.vocab_size
  1699. if config.tie_word_embeddings:
  1700. self.lm_head = Qwen2LMHead(
  1701. config,
  1702. embedding_weights=self.model.embed_tokens.weight,
  1703. transpose_y=True,
  1704. )
  1705. self.tie_weights()
  1706. else:
  1707. self.lm_head = Qwen2LMHead(config)
  1708. self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
  1709. def get_input_embeddings(self):
  1710. return self.model.embed_tokens
  1711. def set_input_embeddings(self, value):
  1712. self.model.embed_tokens = value
  1713. def get_output_embeddings(self):
  1714. return self.lm_head
  1715. def set_output_embeddings(self, new_embeddings):
  1716. self.lm_head = new_embeddings
  1717. def set_decoder(self, decoder):
  1718. self.model = decoder
  1719. def get_decoder(self):
  1720. return self.model
  1721. @classmethod
  1722. def _get_tensor_parallel_mappings(cls, config: Qwen2VLConfig, is_split=True):
  1723. logging.info("Qwen2 inference model _get_tensor_parallel_mappings")
  1724. from paddlenlp.transformers.conversion_utils import split_or_merge_func
  1725. fn = split_or_merge_func(
  1726. is_split=is_split,
  1727. tensor_parallel_degree=config.tensor_parallel_degree,
  1728. tensor_parallel_rank=config.tensor_parallel_rank,
  1729. num_attention_heads=config.num_attention_heads,
  1730. )
  1731. def get_tensor_parallel_split_mappings(num_layers):
  1732. final_actions = {}
  1733. base_actions = {
  1734. "lm_head.weight": partial(fn, is_column=True),
  1735. # Row Linear
  1736. "embed_tokens.weight": partial(fn, is_column=False),
  1737. "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
  1738. "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
  1739. }
  1740. base_actions["layers.0.self_attn.q_proj.weight"] = partial(
  1741. fn, is_column=True
  1742. )
  1743. base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
  1744. # if we have enough num_key_value_heads to split, then split it.
  1745. if config.num_key_value_heads % config.tensor_parallel_degree == 0:
  1746. base_actions["layers.0.self_attn.k_proj.weight"] = partial(
  1747. fn, is_column=True
  1748. )
  1749. base_actions["layers.0.self_attn.v_proj.weight"] = partial(
  1750. fn, is_column=True
  1751. )
  1752. base_actions["layers.0.self_attn.k_proj.bias"] = partial(
  1753. fn, is_column=True
  1754. )
  1755. base_actions["layers.0.self_attn.v_proj.bias"] = partial(
  1756. fn, is_column=True
  1757. )
  1758. if config.fuse_attention_ffn:
  1759. base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
  1760. fn, is_column=True, is_naive_2fuse=True
  1761. )
  1762. else:
  1763. base_actions["layers.0.mlp.gate_proj.weight"] = partial(
  1764. fn, is_column=True
  1765. )
  1766. base_actions["layers.0.mlp.up_proj.weight"] = partial(
  1767. fn, is_column=True
  1768. )
  1769. for key, action in base_actions.items():
  1770. if "layers.0." in key:
  1771. for i in range(num_layers):
  1772. final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
  1773. final_actions[key] = action
  1774. return final_actions
  1775. mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
  1776. return mappings
  1777. @staticmethod
  1778. def get_rope_index(
  1779. spatial_merge_size,
  1780. image_token_id,
  1781. video_token_id,
  1782. vision_start_token_id,
  1783. input_ids: paddle.Tensor,
  1784. image_grid_thw: Optional[paddle.Tensor] = None,
  1785. video_grid_thw: Optional[paddle.Tensor] = None,
  1786. attention_mask: Optional[paddle.Tensor] = None,
  1787. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  1788. """
  1789. Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
  1790. Explanation:
  1791. Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
  1792. For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
  1793. Examples:
  1794. input_ids: [T T T T T], here T is for text.
  1795. temporal position_ids: [0, 1, 2, 3, 4]
  1796. height position_ids: [0, 1, 2, 3, 4]
  1797. width position_ids: [0, 1, 2, 3, 4]
  1798. For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
  1799. and 1D rotary position embedding for text part.
  1800. Examples:
  1801. Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
  1802. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
  1803. vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
  1804. vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
  1805. vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
  1806. text temporal position_ids: [3, 4, 5, 6, 7]
  1807. text height position_ids: [3, 4, 5, 6, 7]
  1808. text width position_ids: [3, 4, 5, 6, 7]
  1809. Here we calculate the text start position_ids as the max vision position_ids plus 1.
  1810. Args:
  1811. input_ids (`paddle.Tensor` of shape `(batch_size, sequence_length)`):
  1812. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  1813. it.
  1814. image_grid_thw (`paddle.Tensor` of shape `(num_images, 3)`, *optional*):
  1815. The temporal, height and width of feature shape of each image in LLM.
  1816. video_grid_thw (`paddle.Tensor` of shape `(num_videos, 3)`, *optional*):
  1817. The temporal, height and width of feature shape of each video in LLM.
  1818. attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1819. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1820. - 1 for tokens that are **not masked**,
  1821. - 0 for tokens that are **masked**.
  1822. Returns:
  1823. position_ids (`paddle.Tensor` of shape `(3, batch_size, sequence_length)`)
  1824. mrope_position_deltas (`paddle.Tensor` of shape `(batch_size)`)
  1825. """
  1826. mrope_position_deltas = []
  1827. if image_grid_thw is not None or video_grid_thw is not None:
  1828. total_input_ids = input_ids
  1829. position_ids = paddle.ones(
  1830. [3, input_ids.shape[0], input_ids.shape[1]], dtype=input_ids.dtype
  1831. )
  1832. image_index, video_index = 0, 0
  1833. for i, input_ids in enumerate(total_input_ids):
  1834. # TODO: CUDA error in some paddle version
  1835. if attention_mask is not None:
  1836. input_ids = paddle.to_tensor(
  1837. input_ids.cpu()[attention_mask[i].cpu() == 1]
  1838. ) # NOTE 原始写法
  1839. image_nums, video_nums = 0, 0
  1840. vision_start_indices = paddle.nonzero(
  1841. input_ids == vision_start_token_id
  1842. ).squeeze(
  1843. 1
  1844. ) # NOTE 原始写法
  1845. vision_tokens = input_ids[vision_start_indices + 1]
  1846. image_nums = (
  1847. (vision_tokens == image_token_id).sum()
  1848. if vision_tokens.numel() > 0
  1849. else 0
  1850. )
  1851. video_nums = (
  1852. (vision_tokens == video_token_id).sum()
  1853. if vision_tokens.numel() > 0
  1854. else 0
  1855. )
  1856. input_tokens = input_ids.tolist()
  1857. llm_pos_ids_list: list = []
  1858. st = 0
  1859. remain_images, remain_videos = image_nums, video_nums
  1860. for _ in range(image_nums + video_nums):
  1861. if image_token_id in input_tokens and remain_images > 0:
  1862. ed_image = input_tokens.index(image_token_id, st)
  1863. else:
  1864. ed_image = len(input_tokens) + 1
  1865. if video_token_id in input_tokens and remain_videos > 0:
  1866. ed_video = input_tokens.index(video_token_id, st)
  1867. else:
  1868. ed_video = len(input_tokens) + 1
  1869. if ed_image < ed_video:
  1870. t, h, w = (
  1871. image_grid_thw[image_index][0],
  1872. image_grid_thw[image_index][1],
  1873. image_grid_thw[image_index][2],
  1874. )
  1875. image_index += 1
  1876. remain_images -= 1
  1877. ed = ed_image
  1878. else:
  1879. t, h, w = (
  1880. video_grid_thw[video_index][0],
  1881. video_grid_thw[video_index][1],
  1882. video_grid_thw[video_index][2],
  1883. )
  1884. video_index += 1
  1885. remain_videos -= 1
  1886. ed = ed_video
  1887. llm_grid_t, llm_grid_h, llm_grid_w = (
  1888. t.item(),
  1889. h.item() // spatial_merge_size,
  1890. w.item() // spatial_merge_size,
  1891. )
  1892. text_len = ed - st
  1893. st_idx = (
  1894. llm_pos_ids_list[-1].max() + 1
  1895. if len(llm_pos_ids_list) > 0
  1896. else 0
  1897. )
  1898. llm_pos_ids_list.append(
  1899. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1900. + st_idx
  1901. )
  1902. t_index = (
  1903. paddle.arange(llm_grid_t)
  1904. .reshape([-1, 1])
  1905. .expand([-1, llm_grid_h * llm_grid_w])
  1906. .flatten()
  1907. )
  1908. h_index = (
  1909. paddle.arange(llm_grid_h)
  1910. .reshape([1, -1, 1])
  1911. .expand([llm_grid_t, -1, llm_grid_w])
  1912. .flatten()
  1913. )
  1914. w_index = (
  1915. paddle.arange(llm_grid_w)
  1916. .reshape([1, 1, -1])
  1917. .expand([llm_grid_t, llm_grid_h, -1])
  1918. .flatten()
  1919. )
  1920. llm_pos_ids_list.append(
  1921. paddle.stack([t_index, h_index, w_index]) + text_len + st_idx
  1922. )
  1923. st = ed + llm_grid_t * llm_grid_h * llm_grid_w
  1924. if st < len(input_tokens):
  1925. st_idx = (
  1926. llm_pos_ids_list[-1].max() + 1
  1927. if len(llm_pos_ids_list) > 0
  1928. else 0
  1929. )
  1930. text_len = len(input_tokens) - st
  1931. llm_pos_ids_list.append(
  1932. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1933. + st_idx
  1934. )
  1935. llm_positions = paddle.concat(llm_pos_ids_list, axis=1).reshape([3, -1])
  1936. if _IS_NPU:
  1937. bool_indices = (
  1938. (attention_mask[i] == 1)
  1939. .unsqueeze(0)
  1940. .tile([position_ids.shape[0], 1])
  1941. )
  1942. position_ids[:, i] = paddle.index_put(
  1943. position_ids[:, i], [bool_indices], llm_positions.reshape([-1])
  1944. )
  1945. else:
  1946. position_ids[..., i, attention_mask[i] == 1] = llm_positions
  1947. mrope_position_deltas.append(
  1948. llm_positions.max() + 1 - len(total_input_ids[i])
  1949. )
  1950. mrope_position_deltas = paddle.to_tensor(mrope_position_deltas).unsqueeze(1)
  1951. else:
  1952. if attention_mask is not None:
  1953. position_ids = paddle.cast(attention_mask, dtype="int64").cumsum(-1) - 1
  1954. position_ids.masked_fill_(mask=attention_mask == 0, value=1)
  1955. position_ids = position_ids.unsqueeze(0).expand([3, -1, -1])
  1956. max_position_ids = position_ids.max(0, keepdim=False)[0].max(
  1957. -1, keepdim=True
  1958. )[0]
  1959. mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  1960. else:
  1961. position_ids = (
  1962. paddle.arange(input_ids.shape[1])
  1963. .reshape([1, 1, -1])
  1964. .expand(shape=[3, input_ids.shape[0], -1])
  1965. )
  1966. mrope_position_deltas = paddle.zeros(
  1967. [input_ids.shape[0], 1], dtype=input_ids.dtype
  1968. )
  1969. return position_ids, mrope_position_deltas
  1970. def update_model_kwargs_for_generation(
  1971. self,
  1972. outputs: ModelOutput,
  1973. model_kwargs: Dict[str, Any],
  1974. is_encoder_decoder: bool = False,
  1975. # num_new_tokens: int = 1,
  1976. ) -> Dict[str, Any]:
  1977. model_kwargs = super().update_model_kwargs_for_generation(
  1978. outputs=outputs,
  1979. model_kwargs=model_kwargs,
  1980. is_encoder_decoder=is_encoder_decoder,
  1981. )
  1982. if getattr(outputs, "rope_deltas", None) is not None:
  1983. model_kwargs["rope_deltas"] = outputs.rope_deltas
  1984. return model_kwargs
  1985. def vision_forward(
  1986. self,
  1987. input_ids: paddle.Tensor,
  1988. inputs_embeds: Optional[paddle.Tensor] = None,
  1989. attention_mask: Optional[paddle.Tensor] = None,
  1990. position_ids: Optional[paddle.Tensor] = None,
  1991. pixel_values: Optional[paddle.Tensor] = None,
  1992. pixel_values_videos: Optional[paddle.Tensor] = None,
  1993. image_grid_thw: Optional[paddle.Tensor] = None,
  1994. video_grid_thw: Optional[paddle.Tensor] = None,
  1995. rope_deltas: Optional[paddle.Tensor] = None,
  1996. ):
  1997. if inputs_embeds is None:
  1998. from paddlenlp.experimental.transformers.qwen2.modeling import (
  1999. Qwen2VLForConditionalGenerationBlockInferenceModel,
  2000. )
  2001. assert isinstance(
  2002. self.model, Qwen2VLForConditionalGenerationBlockInferenceModel
  2003. ), "model is not an instance of Qwen2VLForConditionalGenerationBlockInferenceModel"
  2004. inputs_embeds = self.model.qwen2.embed_tokens(input_ids)
  2005. if pixel_values is not None:
  2006. pixel_values = paddle.cast(pixel_values, paddle.bfloat16)
  2007. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2008. image_mask = input_ids == self.config.image_token_id
  2009. inputs_embeds[image_mask] = image_embeds
  2010. if pixel_values_videos is not None:
  2011. pixel_values_videos = paddle.cast(pixel_values_videos, paddle.bfloat16)
  2012. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2013. video_mask = input_ids == self.config.video_token_id
  2014. inputs_embeds[video_mask] = video_embeds
  2015. return inputs_embeds
  2016. def forward(
  2017. self,
  2018. input_ids: paddle.Tensor = None,
  2019. attention_mask: Optional[paddle.Tensor] = None,
  2020. position_ids: Optional[paddle.Tensor] = None,
  2021. past_key_values: Optional[List[paddle.Tensor]] = None,
  2022. inputs_embeds: Optional[paddle.Tensor] = None,
  2023. labels: Optional[paddle.Tensor] = None,
  2024. use_cache: Optional[bool] = None,
  2025. output_attentions: Optional[bool] = None,
  2026. output_hidden_states: Optional[bool] = None,
  2027. return_dict: Optional[bool] = None,
  2028. pixel_values: Optional[paddle.Tensor] = None,
  2029. pixel_values_videos: Optional[paddle.Tensor] = None,
  2030. image_grid_thw: Optional[paddle.Tensor] = None,
  2031. video_grid_thw: Optional[paddle.Tensor] = None,
  2032. rope_deltas: Optional[paddle.Tensor] = None,
  2033. ):
  2034. """
  2035. Args:
  2036. labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  2037. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  2038. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  2039. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  2040. """
  2041. output_attentions = (
  2042. output_attentions
  2043. if output_attentions is not None
  2044. else self.config.output_attentions
  2045. )
  2046. output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
  2047. return_dict = True # return_dict if return_dict is not None else self.config.use_return_dict
  2048. if inputs_embeds is None:
  2049. inputs_embeds = self.model.embed_tokens(input_ids)
  2050. if pixel_values is not None:
  2051. pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
  2052. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2053. image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
  2054. image_mask = input_ids == self.config.image_token_id
  2055. if self.training:
  2056. inputs_embeds = inputs_embeds.clone()
  2057. inputs_embeds[image_mask] = image_embeds
  2058. if pixel_values_videos is not None:
  2059. pixel_values_videos = paddle.cast(
  2060. pixel_values_videos, inputs_embeds.dtype
  2061. )
  2062. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2063. video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
  2064. video_mask = input_ids == self.config.video_token_id
  2065. inputs_embeds[video_mask] = video_embeds
  2066. if attention_mask is not None:
  2067. attention_mask = attention_mask
  2068. outputs = self.model(
  2069. input_ids=None,
  2070. position_ids=position_ids,
  2071. attention_mask=attention_mask,
  2072. past_key_values=past_key_values,
  2073. inputs_embeds=inputs_embeds,
  2074. use_cache=use_cache,
  2075. output_attentions=output_attentions,
  2076. output_hidden_states=output_hidden_states,
  2077. return_dict=return_dict,
  2078. )
  2079. hidden_states = outputs[0]
  2080. tensor_parallel_output = (
  2081. self.config.tensor_parallel_output
  2082. and self.config.tensor_parallel_degree > 1
  2083. )
  2084. logits = self.lm_head(
  2085. hidden_states, tensor_parallel_output=tensor_parallel_output
  2086. )
  2087. logits = paddle.cast(logits, "float32")
  2088. loss = None
  2089. if labels is not None:
  2090. # Shift so that tokens < n predict n
  2091. shift_logits = logits[..., :-1, :]
  2092. shift_labels = labels[..., 1:]
  2093. # Flatten the tokens
  2094. shift_logits = shift_logits.reshape([-1, self.config.vocab_size])
  2095. shift_labels = shift_labels.reshape([-1])
  2096. if _IS_NPU:
  2097. tmp = F.log_softmax(shift_logits, axis=1)
  2098. loss = F.nll_loss(tmp, shift_labels, reduction="sum")
  2099. else:
  2100. loss_fct = nn.CrossEntropyLoss(reduction="sum")
  2101. loss = loss_fct(shift_logits, shift_labels)
  2102. label_sum = paddle.sum(shift_labels != -100).cast("float32")
  2103. loss = loss / label_sum
  2104. if not return_dict:
  2105. output = (logits,) + tuple(outputs[1:])
  2106. return (loss,) + output if loss is not None else output
  2107. return Qwen2VLCausalLMOutputWithPast(
  2108. loss=loss,
  2109. logits=logits,
  2110. past_key_values=outputs.past_key_values,
  2111. hidden_states=outputs.hidden_states,
  2112. attentions=outputs.attentions,
  2113. rope_deltas=rope_deltas,
  2114. )
  2115. def prepare_inputs_for_generation(
  2116. self,
  2117. input_ids,
  2118. past_key_values=None,
  2119. attention_mask=None,
  2120. inputs_embeds=None,
  2121. cache_position=None,
  2122. position_ids=None,
  2123. use_cache=True,
  2124. pixel_values=None,
  2125. pixel_values_videos=None,
  2126. image_grid_thw=None,
  2127. video_grid_thw=None,
  2128. **kwargs,
  2129. ):
  2130. batch_size, seq_length = input_ids.shape
  2131. if past_key_values is None:
  2132. cache_position = paddle.arange(input_ids.shape[1])
  2133. else:
  2134. cache_position = paddle.to_tensor([seq_length - 1])
  2135. if past_key_values is not None:
  2136. input_ids = input_ids[:, -1].unsqueeze(-1)
  2137. rope_deltas = kwargs.get("rope_deltas", None)
  2138. if attention_mask is not None and position_ids is None:
  2139. if cache_position is None or (
  2140. cache_position is not None and cache_position[0] == 0
  2141. ):
  2142. position_ids, rope_deltas = self.get_rope_index(
  2143. self.config.vision_config.spatial_merge_size,
  2144. self.config.image_token_id,
  2145. self.config.video_token_id,
  2146. self.config.vision_start_token_id,
  2147. input_ids,
  2148. image_grid_thw,
  2149. video_grid_thw,
  2150. attention_mask,
  2151. )
  2152. else:
  2153. batch_size, seq_length = input_ids.shape
  2154. delta = (
  2155. cache_position[0] + rope_deltas
  2156. if cache_position is not None and rope_deltas is not None
  2157. else 0
  2158. )
  2159. position_ids = paddle.arange(seq_length)
  2160. position_ids = position_ids.reshape([1, -1]).expand([batch_size, -1])
  2161. position_ids = position_ids + delta
  2162. position_ids = position_ids.unsqueeze(axis=0).expand([3, -1, -1])
  2163. if cache_position[0] != 0:
  2164. pixel_values = None
  2165. pixel_values_videos = None
  2166. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  2167. if inputs_embeds is not None and cache_position[0] == 0:
  2168. model_inputs = {"inputs_embeds": inputs_embeds}
  2169. else:
  2170. model_inputs = {"input_ids": input_ids}
  2171. model_inputs.update(
  2172. {
  2173. "position_ids": position_ids, # [3, 1, 3602]
  2174. "past_key_values": past_key_values, # DynamicCache()
  2175. "use_cache": use_cache, # 1
  2176. "attention_mask": attention_mask, # [1, 3602]
  2177. "pixel_values": pixel_values, # [14308, 1176]
  2178. "pixel_values_videos": pixel_values_videos,
  2179. "image_grid_thw": image_grid_thw, # [[ 1, 98, 146]]
  2180. "video_grid_thw": video_grid_thw,
  2181. "rope_deltas": rope_deltas, # [[-3504]]
  2182. }
  2183. )
  2184. return model_inputs
  2185. def gme_qwen2_vl_forward(
  2186. self,
  2187. input_ids: paddle.Tensor = None,
  2188. attention_mask: Optional[paddle.Tensor] = None,
  2189. position_ids: Optional[paddle.Tensor] = None,
  2190. past_key_values: Optional[List[paddle.Tensor]] = None,
  2191. inputs_embeds: Optional[paddle.Tensor] = None,
  2192. labels: Optional[paddle.Tensor] = None,
  2193. use_cache: Optional[bool] = None,
  2194. output_attentions: Optional[bool] = None,
  2195. output_hidden_states: Optional[bool] = None,
  2196. return_dict: Optional[bool] = None,
  2197. pixel_values: Optional[paddle.Tensor] = None,
  2198. pixel_values_videos: Optional[paddle.Tensor] = None,
  2199. image_grid_thw: Optional[paddle.Tensor] = None,
  2200. video_grid_thw: Optional[paddle.Tensor] = None,
  2201. rope_deltas: Optional[paddle.Tensor] = None,
  2202. ):
  2203. output_attentions = (
  2204. output_attentions
  2205. if output_attentions is not None
  2206. else self.config.output_attentions
  2207. )
  2208. output_hidden_states = (
  2209. output_hidden_states
  2210. if output_hidden_states is not None
  2211. else self.config.output_hidden_states
  2212. )
  2213. return_dict = True # return_dict if return_dict is not None else self.config.use_return_dict
  2214. if inputs_embeds is None:
  2215. inputs_embeds = self.model.embed_tokens(input_ids)
  2216. if pixel_values is not None:
  2217. # 确保 pixel_values 和 inputs_embeds 使用相同的数据类型
  2218. pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
  2219. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2220. # 确保 image_embeds 和 inputs_embeds 使用相同的数据类型
  2221. image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
  2222. image_mask = input_ids == self.config.image_token_id
  2223. if self.training:
  2224. inputs_embeds = inputs_embeds.clone()
  2225. inputs_embeds[image_mask] = image_embeds
  2226. if pixel_values_videos is not None:
  2227. # 确保 pixel_values_videos 和 inputs_embeds 使用相同的数据类型
  2228. pixel_values_videos = paddle.cast(
  2229. pixel_values_videos, inputs_embeds.dtype
  2230. )
  2231. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2232. # 确保 video_embeds 和 inputs_embeds 使用相同的数据类型
  2233. video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
  2234. video_mask = input_ids == self.config.video_token_id
  2235. inputs_embeds[video_mask] = video_embeds
  2236. if attention_mask is not None:
  2237. attention_mask = attention_mask
  2238. outputs = self.model(
  2239. input_ids=None,
  2240. position_ids=position_ids,
  2241. attention_mask=attention_mask,
  2242. past_key_values=past_key_values,
  2243. inputs_embeds=inputs_embeds,
  2244. use_cache=use_cache,
  2245. output_attentions=output_attentions,
  2246. output_hidden_states=output_hidden_states,
  2247. return_dict=return_dict,
  2248. )
  2249. hidden_states = outputs[0]
  2250. # get last hidden state
  2251. last_hidden_state = hidden_states[:, -1, :]
  2252. return last_hidden_state
  2253. class PPDocBeeInference(Qwen2VLForConditionalGeneration):
  2254. set_inference_operations(get_inference_operations() + ["docbee_generate"])
  2255. @benchmark.timeit_with_options(name="docbee_generate")
  2256. def generate(self, inputs, **kwargs):
  2257. max_new_tokens = kwargs.get("max_new_tokens", 2048)
  2258. temperature = kwargs.get("temperature", 0.1)
  2259. top_p = kwargs.get("top_p", 0.001)
  2260. top_k = kwargs.get("top_k", 1)
  2261. with paddle.no_grad():
  2262. generated_ids = super().generate(
  2263. **inputs,
  2264. max_new_tokens=max_new_tokens,
  2265. temperature=temperature,
  2266. top_p=top_p,
  2267. top_k=top_k,
  2268. )
  2269. return generated_ids