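"""CUDA-graph-friendly patches for the UniMERNet MBart decoder.

The classes below wrap the original decoder / attention modules so that the
key-value cache is updated in place at a fixed ``kvlen`` index. Keeping tensor
shapes and buffer addresses static lets the single-token decoding step be
captured and replayed with ``torch.cuda.CUDAGraph``.
"""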
from typing import Optional, Tuple, Union

import torch
from torch import nn
import os

from unimernet.common.config import Config
import unimernet.tasks as tasks
import argparse
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask


class PatchedMBartLearnedPositionalEmbedding(nn.Module):

    def __init__(self, origin: nn.Module):
        super().__init__()
        self.offset = origin.offset
        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
        self.embedding.weight.data = origin.weight.data

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            0, seq_len, dtype=torch.long, device=self.embedding.weight.device
        )
        positions += past_key_values_length
        positions = positions.expand(bsz, -1)
        return self.embedding(positions + self.offset)
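

# PatchedMBartDecoder mirrors the forward pass of the original (Custom)MBartDecoder,
# but reads the decode position from the shared `kvlen` tensor so that position
# embeddings stay correct when the step is replayed inside a CUDA graph. Prefill
# calls (no cache yet, or a cache shorter than the attention mask) fall back to
# the original module.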
class PatchedMBartDecoder(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.origin = origin
        self.kvlen = kvlen
        self.config = origin.config
        self.embed_tokens = origin.embed_tokens
        self.embed_scale = origin.embed_scale
        self._use_flash_attention_2 = origin._use_flash_attention_2
        self.embed_positions = origin.embed_positions
        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
        self.layernorm_embedding = origin.layernorm_embedding
        self.layers = origin.layers
        self.layer_norm = origin.layer_norm
        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        count_pred: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        run_origin = False
        if past_key_values is None:
            run_origin = True
        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
            run_origin = True
        if run_origin:
            return self.origin(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.patched_embed_positions(input, self.kvlen)
        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)

        # TODO: add counting context weight to hidden_states
        if count_pred is not None:
            count_context_weight = self.counting_context_weight(count_pred)
            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)

        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
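

# PatchedMBartAttention reimplements MBartAttention for single-token decoding with
# a preallocated cache: when the cached key/value tensors already span the full
# attention mask, the new key/value for the current step is written in place at
# index `kvlen` instead of being concatenated, so tensor shapes stay static across
# steps (a requirement for CUDA graph replay).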
class PatchedMBartAttention(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.embed_dim = origin.embed_dim
        self.num_heads = origin.num_heads
        self.dropout = origin.dropout
        self.head_dim = origin.head_dim
        self.config = origin.config
        self.scaling = origin.scaling
        self.is_decoder = origin.is_decoder
        self.is_causal = origin.is_causal

        self.k_proj = origin.k_proj
        self.v_proj = origin.v_proj
        self.q_proj = origin.q_proj
        self.out_proj = origin.out_proj
        self.kvlen = kvlen

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value[0].size(-2) < attention_mask.size(-1):
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                past_key_value[0][:, :, self.kvlen[None]] = key_states
                past_key_value[1][:, :, self.kvlen[None]] = value_states
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = attn_weights

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
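

# PatchedMBartSqueezeAttention applies the same in-place cache update to the
# "squeeze" attention variant, which uses a separate `squeeze_head_dim` for
# queries/keys while values keep the full `head_dim`.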
class PatchedMBartSqueezeAttention(nn.Module):

    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
        super().__init__()
        self.embed_dim = origin.embed_dim
        self.num_heads = origin.num_heads
        self.dropout = origin.dropout
        self.head_dim = origin.head_dim
        self.squeeze_head_dim = origin.squeeze_head_dim
        self.config = origin.config
        self.scaling = origin.scaling
        self.is_decoder = origin.is_decoder

        self.q_proj = origin.q_proj
        self.k_proj = origin.k_proj
        self.v_proj = origin.v_proj
        self.out_proj = origin.out_proj
        self.kvlen = kvlen

    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()

    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
            if past_key_value[0].size(-2) < attention_mask.size(-1):
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                past_key_value[0][:, :, self.kvlen[None]] = key_states
                past_key_value[1][:, :, self.kvlen[None]] = value_states
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            # self_attention
            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
        value_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*value_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
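

# patch_model recursively swaps every MBartAttention / MBartSqueezeAttention child
# for its patched counterpart and, at the top level, wraps a CustomMBartDecoder in
# PatchedMBartDecoder. All patched modules share the same `kvlen` tensor.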
def patch_model(model: nn.Module, kvlen: torch.LongTensor):
    for name, child in model.named_children():
        cls_name = type(child).__name__
        if cls_name == 'MBartAttention':
            patched_child = PatchedMBartAttention(child, kvlen)
            model.register_module(name, patched_child)
        elif cls_name == 'MBartSqueezeAttention':
            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
            model.register_module(name, patched_child)
        else:
            patch_model(child, kvlen)

    cls_name = type(model).__name__
    if cls_name == 'CustomMBartDecoder':
        model = PatchedMBartDecoder(model, kvlen)
    return model


def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


def get_graph_key(batch_size: int, kvlens: int):
    batch_size = next_power_of_2(batch_size)
    kvlens = next_power_of_2(kvlens)
    batch_size = max(8, batch_size)
    kvlens = max(32, kvlens)
    return batch_size, kvlens
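

# Example of the bucketing above: a request with batch_size=3 and kvlens=100 is
# rounded up to the graph key (8, 128), so nearby shapes reuse the same captured
# graph instead of triggering a new capture for every exact shape.

# GraphRunnerImpl owns one captured CUDA graph plus views into the shared static
# input/output buffers. Calling it copies the new inputs into those buffers,
# replays the graph, and slices the outputs back to the live batch size / kv length.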
class GraphRunnerImpl:

    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
        self.model = model
        self.graph = graph
        self.input_buffers = input_buffers
        self.output_buffers = output_buffers

    @staticmethod
    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
        input_ids = input_buffers['input_ids'][:batch_size]
        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
        kvlen = input_buffers['kvlen']
        past_key_values = []
        for past_key_value in input_buffers['past_key_values']:
            k0 = past_key_value[0][:batch_size, :, :kvlens]
            v0 = past_key_value[1][:batch_size, :, :kvlens]
            k1 = past_key_value[2][:batch_size]
            v1 = past_key_value[3][:batch_size]
            past_key_values.append((k0, v0, k1, v1))

        input_buffers = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            kvlen=kvlen,
        )
        return input_buffers

    @staticmethod
    def fill_input_buffers(
        input_buffer: dict,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        input_buffer['input_ids'][:batch_size] = input_ids
        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
            input_buffer['attention_mask'].fill_(0)
            input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states

        if past_key_values is not None:
            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
                idx = 0
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size, :, :kvlens - 1] = kv[idx]
                idx = 1
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size, :, :kvlens - 1] = kv[idx]
                idx = 2
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size] = kv[idx]
                idx = 3
                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
                    buf_kv[idx].fill_(0)
                    buf_kv[idx][:batch_size] = kv[idx]

        input_buffer['kvlen'].fill_(kvlens - 1)

    @classmethod
    @torch.inference_mode()
    def capture(cls,
                model: nn.Module,
                input_buffers: dict,
                pool,
                warmup: bool = False,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                count_pred: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                cross_attn_head_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        graph_key = get_graph_key(batch_size, kvlens)
        batch_size = graph_key[0]
        kvlens = graph_key[1]

        input_buffers = cls.extract_input_buffers(input_buffers,
                                                  batch_size=batch_size,
                                                  kvlens=kvlens)
        cls.fill_input_buffers(input_buffers,
                               input_ids,
                               attention_mask,
                               encoder_hidden_states,
                               past_key_values)

        input_ids = input_buffers['input_ids']
        attention_mask = input_buffers['attention_mask']
        encoder_hidden_states = input_buffers['encoder_hidden_states']
        past_key_values = input_buffers['past_key_values']

        if warmup:
            # warmup
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        output_buffers = dict(
            last_hidden_state=outputs['last_hidden_state'],
            past_key_values=outputs['past_key_values'],
        )
        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)

    def __call__(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 count_pred: Optional[torch.FloatTensor] = None,
                 encoder_hidden_states: Optional[torch.FloatTensor] = None,
                 encoder_attention_mask: Optional[torch.LongTensor] = None,
                 head_mask: Optional[torch.Tensor] = None,
                 cross_attn_head_mask: Optional[torch.Tensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 ):
        batch_size = input_ids.size(0)
        kvlens = attention_mask.size(1)
        self.fill_input_buffers(self.input_buffers,
                                input_ids,
                                attention_mask,
                                encoder_hidden_states,
                                past_key_values)
        self.graph.replay()

        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
        past_key_values = []
        for past_key_value in self.output_buffers['past_key_values']:
            k0 = past_key_value[0][:batch_size, :, :kvlens]
            v0 = past_key_value[1][:batch_size, :, :kvlens]
            k1 = past_key_value[2][:batch_size]
            v1 = past_key_value[3][:batch_size]
            past_key_values.append((k0, v0, k1, v1))

        outputs = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=last_hidden_state,
            past_key_values=past_key_values,
        )
        return outputs
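

# GraphRunner is the nn.Module facade: it patches the wrapped decoder, lazily
# allocates the static buffers on first use, falls back to eager execution for
# prefill / oversized inputs, and keeps one GraphRunnerImpl per (batch, kvlen)
# bucket in `impl_map`.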
class GraphRunner(nn.Module):

    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int,
                 dtype: torch.dtype = torch.float16, device: torch.device = 'cuda'):
        super().__init__()
        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
        model = patch_model(model.to(dtype), self.kvlen)
        self.model = model
        self.max_batchs = max_batchs
        self.max_kvlens = max_kvlens
        self.device = device

        self.input_buffers = None
        self.impl_map = dict()
        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self.warmuped = False

    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
        max_batchs = self.max_batchs
        max_kvlens = self.max_kvlens
        device = self.device
        config = self.model.config
        d_model = config.d_model
        decoder_layers = config.decoder_layers
        num_heads = config.decoder_attention_heads
        head_dim = d_model // num_heads
        self_attn = self.model.layers[0].self_attn
        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)

        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
        past_key_values = []
        for _ in range(decoder_layers):
            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
            past_key_values.append((k0, v0, k1, v1))

        self.input_buffers = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            kvlen=self.kvlen,
        )

    @torch.inference_mode()
    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                count_pred: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                cross_attn_head_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                ):
        batch_size, qlens = input_ids.size()
        kvlens = attention_mask.size(1)

        eager_mode = False
        if qlens != 1:
            eager_mode = True

        if past_key_values is None:
            eager_mode = True
        else:
            for past_key_value in past_key_values:
                if past_key_value is None:
                    eager_mode = True
                    break

        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
            eager_mode = True

        if eager_mode:
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict)

        # create the static buffers if they do not exist yet.
        if self.input_buffers is None:
            encoder_kvlens = encoder_hidden_states.size(1)
            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)

        graph_key = get_graph_key(batch_size, kvlens)
        if graph_key not in self.impl_map:
            warmup = False
            if not self.warmuped:
                warmup = True
                self.warmuped = True
            impl = GraphRunnerImpl.capture(
                self.model,
                self.input_buffers,
                self.graph_pool_handle,
                warmup=warmup,
                input_ids=input_ids,
                attention_mask=attention_mask,
                count_pred=count_pred,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            self.impl_map[graph_key] = impl
        impl = self.impl_map[graph_key]
        ret = impl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            count_pred=count_pred,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return ret
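

# A minimal usage sketch. Assumption: the exact attribute path to the MBart decoder
# inside a loaded UniMERNet model depends on the UniMERNet version; it is written
# here as `wrapped.model.model.decoder` purely for illustration.
#
#     decoder = wrapped.model.model.decoder
#     wrapped.model.model.decoder = GraphRunner(
#         decoder, max_batchs=64, max_kvlens=1024,
#         dtype=torch.float16, device='cuda')
#
# After the swap, prefill and oversized batches still run eagerly, while each
# single-token decode step replays the captured CUDA graph for its shape bucket.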