liukaiwen 1 год назад
Родитель
Сommit
a3358878b3
2 измененных файлов с 905 добавлено и 0 удалено
  1. 899 0
      magic_pdf/model/mfr_cudagraph.py
  2. 6 0
      magic_pdf/model/pdf_extract_kit.py

+ 899 - 0
magic_pdf/model/mfr_cudagraph.py

@@ -0,0 +1,899 @@
+from typing import Optional, Tuple, Union
+import torch
+from torch import nn
+import os
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+import argparse
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+
+class PatchedMBartLearnedPositionalEmbedding(nn.Module):
+
+    def __init__(self, origin: nn.Module):
+        super().__init__()
+        self.offset = origin.offset
+        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
+        self.embedding.weight.data = origin.weight.data
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(0, seq_len, dtype=torch.long, device=self.embedding.weight.device
+        )
+        positions += past_key_values_length
+        positions = positions.expand(bsz, -1)
+
+        return self.embedding(positions + self.offset)
+
+
+class PatchedMBartDecoder(nn.Module):
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        super().__init__()
+        self.origin = origin
+        self.kvlen = kvlen
+
+        self.config = origin.config
+        self.embed_tokens = origin.embed_tokens
+        self.embed_scale = origin.embed_scale
+        self._use_flash_attention_2 = origin._use_flash_attention_2
+        self.embed_positions = origin.embed_positions
+        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
+        self.layernorm_embedding = origin.layernorm_embedding
+        self.layers = origin.layers
+        self.layer_norm = origin.layer_norm
+
+        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        run_origin = False
+        if past_key_values is None:
+            run_origin = True
+        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
+            run_origin = True
+
+        if run_origin:
+            return self.origin(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.patched_embed_positions(input, self.kvlen)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+        # TODO: add counting context weight to hidden_states
+        if count_pred is not None:
+            count_context_weight = self.counting_context_weight(count_pred)
+            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                cross_attn_layer_head_mask=(
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                ),
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class PatchedMBartAttention(nn.Module):
+
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        super().__init__()
+        self.embed_dim = origin.embed_dim
+        self.num_heads = origin.num_heads
+        self.dropout = origin.dropout
+        self.head_dim = origin.head_dim
+        self.config = origin.config
+
+        self.scaling = origin.scaling
+        self.is_decoder = origin.is_decoder
+        self.is_causal = origin.is_causal
+
+        self.k_proj = origin.k_proj
+        self.v_proj = origin.v_proj
+        self.q_proj = origin.q_proj
+        self.out_proj = origin.out_proj
+        self.kvlen = kvlen
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+            if past_key_value[0].size(-2) < attention_mask.size(-1):
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            else:
+                past_key_value[0][:, :, self.kvlen[None]] = key_states
+                past_key_value[1][:, :, self.kvlen[None]] = value_states
+                key_states = past_key_value[0]
+                value_states = past_key_value[1]
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = attn_weights
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        # attn_output = self.out_proj(attn_output)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class PatchedMBartSqueezeAttention(nn.Module):
+
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        super().__init__()
+        self.embed_dim = origin.embed_dim
+        self.num_heads = origin.num_heads
+        self.dropout = origin.dropout
+        self.head_dim = origin.head_dim
+        self.squeeze_head_dim=origin.squeeze_head_dim
+        self.config = origin.config
+
+        self.scaling = origin.scaling
+        self.is_decoder = origin.is_decoder
+        self.scaling = origin.scaling
+
+        self.q_proj = origin.q_proj
+        self.k_proj = origin.k_proj
+        self.v_proj = origin.v_proj
+        self.out_proj = origin.out_proj
+        self.kvlen = kvlen
+
+    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
+
+    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+            if past_key_value[0].size(-2) < attention_mask.size(-1):
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            else:
+                past_key_value[0][:, :, self.kvlen[None]] = key_states
+                past_key_value[1][:, :, self.kvlen[None]] = value_states
+                key_states = past_key_value[0]
+                value_states = past_key_value[1]
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
+        value_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*value_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+def patch_model(model: nn.Module, kvlen: torch.LongTensor):
+    for name, child in model.named_children():
+        cls_name = type(child).__name__
+        if cls_name == 'MBartAttention':
+            patched_child = PatchedMBartAttention(child, kvlen)
+            model.register_module(name, patched_child)
+        elif cls_name == 'MBartSqueezeAttention':
+            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
+            model.register_module(name, patched_child)
+        else:
+            patch_model(child, kvlen)
+
+    cls_name = type(model).__name__
+    if cls_name == 'CustomMBartDecoder':
+        model = PatchedMBartDecoder(model, kvlen)
+    return model
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n."""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
+def get_graph_key(batch_size: int, kvlens: int):
+    batch_size = next_power_of_2(batch_size)
+    kvlens = next_power_of_2(kvlens)
+
+    batch_size = max(8, batch_size)
+    kvlens = max(32, kvlens)
+
+    return batch_size, kvlens
+
+
+class GraphRunnerImpl:
+
+    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
+        self.model = model
+        self.graph = graph
+        self.input_buffers = input_buffers
+        self.output_buffers = output_buffers
+
+    @staticmethod
+    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
+        input_ids = input_buffers['input_ids'][:batch_size]
+        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
+        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
+        kvlen=input_buffers['kvlen']
+
+        past_key_values = []
+        for past_key_value in input_buffers['past_key_values']:
+            k0 = past_key_value[0][:batch_size, :, :kvlens]
+            v0 = past_key_value[1][:batch_size, :, :kvlens]
+            k1 = past_key_value[2][:batch_size]
+            v1 = past_key_value[3][:batch_size]
+            past_key_values.append((k0, v0, k1, v1))
+
+        input_buffers = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            kvlen=kvlen,
+        )
+        return input_buffers
+
+    @staticmethod
+    def fill_input_buffers(
+        input_buffer: dict,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        ):
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+
+        input_buffer['input_ids'][:batch_size] = input_ids
+
+        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
+            input_buffer['attention_mask'].fill_(0)
+        input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
+        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
+
+        if past_key_values is not None:
+            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
+                idx = 0
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
+                idx = 1
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
+
+                idx = 2
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size] = kv[idx]
+                idx = 3
+                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
+                    buf_kv[idx].fill_(0)
+                    buf_kv[idx][:batch_size] = kv[idx]
+
+        input_buffer['kvlen'].fill_(kvlens - 1)
+
+    @classmethod
+    @torch.inference_mode()
+    def capture(cls,
+                model: nn.Module,
+                input_buffers: dict,
+                pool,
+                warmup: bool = False,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                count_pred: Optional[torch.FloatTensor] = None,
+                encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                encoder_attention_mask: Optional[torch.LongTensor] = None,
+                head_mask: Optional[torch.Tensor] = None,
+                cross_attn_head_mask: Optional[torch.Tensor] = None,
+                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,):
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+
+        graph_key = get_graph_key(batch_size, kvlens)
+        batch_size = graph_key[0]
+        kvlens = graph_key[1]
+
+        input_buffers = cls.extract_input_buffers(input_buffers,
+                                                  batch_size=batch_size,
+                                                  kvlens=kvlens)
+        cls.fill_input_buffers(input_buffers,
+                               input_ids,
+                               attention_mask,
+                               encoder_hidden_states,
+                               past_key_values)
+
+        input_ids = input_buffers['input_ids']
+        attention_mask = input_buffers['attention_mask']
+        encoder_hidden_states = input_buffers['encoder_hidden_states']
+        past_key_values = input_buffers['past_key_values']
+
+        if warmup:
+            # warmup
+            model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph,
+                              pool=pool):
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict)
+
+        output_buffers = dict(
+            last_hidden_state=outputs['last_hidden_state'],
+            past_key_values=outputs['past_key_values'],
+            )
+
+        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)
+
+    def __call__(self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        ):
+        batch_size = input_ids.size(0)
+        kvlens = attention_mask.size(1)
+        self.fill_input_buffers(self.input_buffers,
+                               input_ids,
+                               attention_mask,
+                               encoder_hidden_states,
+                               past_key_values)
+
+        self.graph.replay()
+
+        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
+
+        past_key_values = []
+        for past_key_value in self.output_buffers['past_key_values']:
+            k0 = past_key_value[0][:batch_size, :, :kvlens]
+            v0 = past_key_value[1][:batch_size, :, :kvlens]
+            k1 = past_key_value[2][:batch_size]
+            v1 = past_key_value[3][:batch_size]
+            past_key_values.append((k0, v0, k1, v1))
+
+        outputs = BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=last_hidden_state,
+            past_key_values=past_key_values,
+        )
+        return outputs
+
+class GraphRunner(nn.Module):
+
+    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'):
+        super().__init__()
+
+        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
+        model = patch_model(model.to(dtype), self.kvlen)
+        self.model = model
+        self.max_batchs = max_batchs
+        self.max_kvlens = max_kvlens
+        self.device = device
+
+        self.input_buffers = None
+
+        self.impl_map = dict()
+        self.graph_pool_handle = torch.cuda.graph_pool_handle()
+        self.warmuped = False
+
+    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
+        max_batchs = self.max_batchs
+        max_kvlens = self.max_kvlens
+        device = self.device
+        config = self.model.config
+
+        d_model = config.d_model
+        decoder_layers = config.decoder_layers
+        num_heads = config.decoder_attention_heads
+
+        head_dim = d_model // num_heads
+        self_attn = self.model.layers[0].self_attn
+        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)
+
+        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
+        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
+        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
+
+        past_key_values = []
+        for _ in range(decoder_layers):
+            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
+            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
+            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
+            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
+
+            past_key_values.append((k0, v0, k1, v1))
+
+        self.input_buffers = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            kvlen=self.kvlen
+        )
+
+    @torch.inference_mode()
+    def forward(self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        ):
+        batch_size, qlens = input_ids.size()
+        kvlens = attention_mask.size(1)
+
+        eager_mode = False
+
+        if qlens != 1:
+            eager_mode = True
+
+        if past_key_values is None:
+            eager_mode = True
+        else:
+            for past_key_value in past_key_values:
+                if past_key_value is None:
+                    eager_mode = True
+                    break
+
+        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
+            eager_mode = True
+
+        if eager_mode:
+            return self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,)
+
+        # create buffer if not exists.
+        if self.input_buffers is None:
+            encoder_kvlens = encoder_hidden_states.size(1)
+            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
+
+        graph_key = get_graph_key(batch_size, kvlens)
+        if graph_key not in self.impl_map:
+            warmup = False
+            if not self.warmuped:
+                warmup = True
+                self.warmuped = True
+            impl = GraphRunnerImpl.capture(
+                self.model,
+                self.input_buffers,
+                self.graph_pool_handle,
+                warmup=warmup,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            self.impl_map[graph_key] = impl
+        impl = self.impl_map[graph_key]
+
+        ret = impl(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+        )
+        return ret

+ 6 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -5,6 +5,7 @@ import time
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
+from .mfr_cudagraph import GraphRunner
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
@@ -67,6 +68,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model.to(_device_)
     model.eval()
+    model = model.to(_device_)
+    if 'cuda' in _device_:
+        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
+                                     device=_device_)
+        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]