# mfr_cudagraph.py
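# CUDA-graph decoding for the UniMERNet MBart decoder.
#
# The idea: once the KV cache has been pre-allocated to a fixed (bucketed)
# shape, every single-token decode step launches an identical sequence of
# kernels, so it can be captured once with torch.cuda.CUDAGraph and replayed
# for later steps. The classes below patch the stock MBart modules so that
# positions and the self-attention cache are indexed through a persistent
# `kvlen` tensor instead of Python integers, which keeps tensor shapes and
# buffer addresses stable across replays.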
from typing import Optional, Tuple, Union
import torch
from torch import nn
import os
from unimernet.common.config import Config
import unimernet.tasks as tasks
import argparse
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
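

# The learned positional embedding is re-wrapped so that the decode-time
# offset can be supplied as a device tensor (during graph replay,
# `past_key_values_length` is the shared `kvlen` tensor) rather than a Python
# int, keeping the captured graph free of host-side values.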
class PatchedMBartLearnedPositionalEmbedding(nn.Module):

    def __init__(self, origin: nn.Module):
        super().__init__()
        self.offset = origin.offset
        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
        self.embedding.weight.data = origin.weight.data

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            0, seq_len, dtype=torch.long, device=self.embedding.weight.device
        )
        positions += past_key_values_length
        positions = positions.expand(bsz, -1)
        return self.embedding(positions + self.offset)
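

# Wrapper around CustomMBartDecoder. Prefill (no cache yet, or a cache shorter
# than the attention mask) is delegated to the original module; the
# graph-friendly path below mirrors the upstream forward but reads positions
# from `kvlen`, so single-token decoding runs with static shapes.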
class PatchedMBartDecoder(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.origin = origin
        self.kvlen = kvlen
        self.config = origin.config
        self.embed_tokens = origin.embed_tokens
        self.embed_scale = origin.embed_scale
        self._use_flash_attention_2 = origin._use_flash_attention_2
        self.embed_positions = origin.embed_positions
        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
        self.layernorm_embedding = origin.layernorm_embedding
        self.layers = origin.layers
        self.layer_norm = origin.layer_norm
        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        count_pred: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        run_origin = False
        if past_key_values is None:
            run_origin = True
        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
            run_origin = True
        if run_origin:
            return self.origin(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.patched_embed_positions(input, self.kvlen)
        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)

        # TODO: add counting context weight to hidden_states
        if count_pred is not None:
            count_context_weight = self.counting_context_weight(count_pred)
            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)

        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
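

# Self-attention with an in-place cache update: during graph replay, the new
# key/value for the current token is written into the pre-allocated cache at
# index `kvlen` instead of growing the cache with torch.cat, so buffer
# addresses never change. Attention dropout is skipped in this patched path
# (the runner only executes it under inference_mode).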
class PatchedMBartAttention(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.embed_dim = origin.embed_dim
        self.num_heads = origin.num_heads
        self.dropout = origin.dropout
        self.head_dim = origin.head_dim
        self.config = origin.config
        self.scaling = origin.scaling
        self.is_decoder = origin.is_decoder
        self.is_causal = origin.is_causal
        self.k_proj = origin.k_proj
        self.v_proj = origin.v_proj
        self.q_proj = origin.q_proj
        self.out_proj = origin.out_proj
        self.kvlen = kvlen

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value[0].size(-2) < attention_mask.size(-1):
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                past_key_value[0][:, :, self.kvlen[None]] = key_states
                past_key_value[1][:, :, self.kvlen[None]] = value_states
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = attn_weights
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
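

# The same in-place cache trick for the squeeze-attention variant, which
# projects queries/keys to `squeeze_head_dim` while values keep the full
# `head_dim`.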
class PatchedMBartSqueezeAttention(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.embed_dim = origin.embed_dim
        self.num_heads = origin.num_heads
        self.dropout = origin.dropout
        self.head_dim = origin.head_dim
        self.squeeze_head_dim = origin.squeeze_head_dim
        self.config = origin.config
        self.scaling = origin.scaling
        self.is_decoder = origin.is_decoder
        self.q_proj = origin.q_proj
        self.k_proj = origin.k_proj
        self.v_proj = origin.v_proj
        self.out_proj = origin.out_proj
        self.kvlen = kvlen

    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()

    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
            if past_key_value[0].size(-2) < attention_mask.size(-1):
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                past_key_value[0][:, :, self.kvlen[None]] = key_states
                past_key_value[1][:, :, self.kvlen[None]] = value_states
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            # self_attention
            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
        value_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*value_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
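

# Recursively swap MBartAttention / MBartSqueezeAttention modules for their
# patched counterparts and wrap a CustomMBartDecoder root in
# PatchedMBartDecoder. All patched modules share one `kvlen` tensor.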
def patch_model(model: nn.Module, kvlen: torch.LongTensor):
    for name, child in model.named_children():
        cls_name = type(child).__name__
        if cls_name == 'MBartAttention':
            patched_child = PatchedMBartAttention(child, kvlen)
            model.register_module(name, patched_child)
        elif cls_name == 'MBartSqueezeAttention':
            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
            model.register_module(name, patched_child)
        else:
            patch_model(child, kvlen)
    cls_name = type(model).__name__
    if cls_name == 'CustomMBartDecoder':
        model = PatchedMBartDecoder(model, kvlen)
    return model
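

# Shapes are bucketed so that a handful of captured graphs cover many request
# sizes. For example, batch_size=3, kvlens=50 maps to the graph key (8, 64):
# next_power_of_2(3) = 4, clamped up to the minimum of 8, and
# next_power_of_2(50) = 64, already above the minimum of 32.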
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
def get_graph_key(batch_size: int, kvlens: int):
    batch_size = next_power_of_2(batch_size)
    kvlens = next_power_of_2(kvlens)
    batch_size = max(8, batch_size)
    kvlens = max(32, kvlens)
    return batch_size, kvlens
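

# Holds one captured graph together with the input/output buffer views it was
# recorded against. Replay is only valid when the same buffers are reused, so
# fill_input_buffers copies fresh inputs into them, skipping the copy when the
# caller already passes views of those buffers (detected via data_ptr()).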
class GraphRunnerImpl:

    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
        self.model = model
        self.graph = graph
        self.input_buffers = input_buffers
        self.output_buffers = output_buffers

    @staticmethod
    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
        input_ids = input_buffers['input_ids'][:batch_size]
        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
        kvlen = input_buffers['kvlen']
        past_key_values = []
        for past_key_value in input_buffers['past_key_values']:
            k0 = past_key_value[0][:batch_size, :, :kvlens]
            v0 = past_key_value[1][:batch_size, :, :kvlens]
            k1 = past_key_value[2][:batch_size]
            v1 = past_key_value[3][:batch_size]
            past_key_values.append((k0, v0, k1, v1))
        input_buffers = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            kvlen=kvlen,
        )
        return input_buffers

    @staticmethod
    def fill_input_buffers(
        input_buffer: dict,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        input_buffer['input_ids'][:batch_size] = input_ids
        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
            input_buffer['attention_mask'].fill_(0)
            input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
        if past_key_values is not None:
            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
                idx = 0
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size, :, :kvlens - 1] = kv[idx]
                idx = 1
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size, :, :kvlens - 1] = kv[idx]
                idx = 2
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size] = kv[idx]
                idx = 3
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size] = kv[idx]
        input_buffer['kvlen'].fill_(kvlens - 1)

    @classmethod
    @torch.inference_mode()
    def capture(cls,
                model: nn.Module,
                input_buffers: dict,
                pool,
                warmup: bool = False,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                count_pred: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                cross_attn_head_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        graph_key = get_graph_key(batch_size, kvlens)
        batch_size = graph_key[0]
        kvlens = graph_key[1]
        input_buffers = cls.extract_input_buffers(input_buffers,
                                                  batch_size=batch_size,
                                                  kvlens=kvlens)
        cls.fill_input_buffers(input_buffers,
                               input_ids,
                               attention_mask,
                               encoder_hidden_states,
                               past_key_values)
        input_ids = input_buffers['input_ids']
        attention_mask = input_buffers['attention_mask']
        encoder_hidden_states = input_buffers['encoder_hidden_states']
        past_key_values = input_buffers['past_key_values']

        if warmup:
            # warmup
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        output_buffers = dict(
            last_hidden_state=outputs['last_hidden_state'],
            past_key_values=outputs['past_key_values'],
        )
        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)

    def __call__(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 count_pred: Optional[torch.FloatTensor] = None,
                 encoder_hidden_states: Optional[torch.FloatTensor] = None,
                 encoder_attention_mask: Optional[torch.LongTensor] = None,
                 head_mask: Optional[torch.Tensor] = None,
                 cross_attn_head_mask: Optional[torch.Tensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 ):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        self.fill_input_buffers(self.input_buffers,
                                input_ids,
                                attention_mask,
                                encoder_hidden_states,
                                past_key_values)
        self.graph.replay()
        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
        past_key_values = []
        for past_key_value in self.output_buffers['past_key_values']:
            k0 = past_key_value[0][:batch_size, :, :kvlens]
            v0 = past_key_value[1][:batch_size, :, :kvlens]
            k1 = past_key_value[2][:batch_size]
            v1 = past_key_value[3][:batch_size]
            past_key_values.append((k0, v0, k1, v1))
        outputs = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=last_hidden_state,
            past_key_values=past_key_values,
        )
        return outputs
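

# Top-level wrapper: falls back to eager execution for prefill, multi-token
# inputs, or shapes beyond the pre-allocated maxima; otherwise it captures
# (once per bucketed (batch_size, kvlens) key) and replays a CUDA graph for
# the single-token decode step.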
class GraphRunner(nn.Module):

    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int,
                 dtype: torch.dtype = torch.float16, device: torch.device = 'cuda'):
        super().__init__()
        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
        model = patch_model(model.to(dtype), self.kvlen)
        self.model = model
        self.max_batchs = max_batchs
        self.max_kvlens = max_kvlens
        self.device = device
        self.input_buffers = None
        self.impl_map = dict()
        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self.warmuped = False

    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
        max_batchs = self.max_batchs
        max_kvlens = self.max_kvlens
        device = self.device
        config = self.model.config
        d_model = config.d_model
        decoder_layers = config.decoder_layers
        num_heads = config.decoder_attention_heads
        head_dim = d_model // num_heads
        self_attn = self.model.layers[0].self_attn
        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)

        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
        past_key_values = []
        for _ in range(decoder_layers):
            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
            past_key_values.append((k0, v0, k1, v1))
        self.input_buffers = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            kvlen=self.kvlen,
        )

    @torch.inference_mode()
    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                count_pred: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                cross_attn_head_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                ):
        batch_size, qlens = input_ids.size()
        kvlens = attention_mask.size(1)

        eager_mode = False
        if qlens != 1:
            eager_mode = True
        if past_key_values is None:
            eager_mode = True
        else:
            for past_key_value in past_key_values:
                if past_key_value is None:
                    eager_mode = True
                    break
        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
            eager_mode = True

        if eager_mode:
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        # create buffer if not exists.
        if self.input_buffers is None:
            encoder_kvlens = encoder_hidden_states.size(1)
            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)

        graph_key = get_graph_key(batch_size, kvlens)
        if graph_key not in self.impl_map:
            warmup = False
            if not self.warmuped:
                warmup = True
                self.warmuped = True
            impl = GraphRunnerImpl.capture(
                self.model,
                self.input_buffers,
                self.graph_pool_handle,
                warmup=warmup,
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            self.impl_map[graph_key] = impl
        impl = self.impl_map[graph_key]
        ret = impl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            count_pred=count_pred,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return ret
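

# Rough usage sketch (not part of the code above; attribute paths and the
# helpers `locate_decoder` / `attach_decoder` are hypothetical and depend on
# how the UniMERNet model is constructed elsewhere in this file):
#
#   decoder = locate_decoder(mfr_model)            # hypothetical helper
#   graph_decoder = GraphRunner(decoder, max_batchs=128, max_kvlens=1024,
#                               dtype=torch.float16, device='cuda')
#   attach_decoder(mfr_model, graph_decoder)       # hypothetical helper
#
# After swapping the decoder in, generation proceeds as usual: the first
# (prefill) step runs eagerly, and each subsequent single-token step is served
# by a captured CUDA graph keyed on the bucketed (batch_size, kvlens) pair.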