فهرست منبع

perf: table model update with PP OCRv4

liukaiwen 1 سال پیش
والد
کامیت
4949408c9d
2فایلهای تغییر یافته به همراه0 افزوده شده و 905 حذف شده
  1. 0 899
      magic_pdf/model/mfr_cudagraph.py
  2. 0 6
      magic_pdf/model/pdf_extract_kit.py

+ 0 - 899
magic_pdf/model/mfr_cudagraph.py

@@ -1,899 +0,0 @@
-from typing import Optional, Tuple, Union
-import torch
-from torch import nn
-import os
-from unimernet.common.config import Config
-import unimernet.tasks as tasks
-import argparse
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
-
-class PatchedMBartLearnedPositionalEmbedding(nn.Module):
-
-    def __init__(self, origin: nn.Module):
-        super().__init__()
-        self.offset = origin.offset
-        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
-        self.embedding.weight.data = origin.weight.data
-
-    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
-        """`input_ids' shape is expected to be [bsz x seqlen]."""
-
-        bsz, seq_len = input_ids.shape[:2]
-        positions = torch.arange(0, seq_len, dtype=torch.long, device=self.embedding.weight.device
-        )
-        positions += past_key_values_length
-        positions = positions.expand(bsz, -1)
-
-        return self.embedding(positions + self.offset)
-
-
-class PatchedMBartDecoder(nn.Module):
-    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
-        super().__init__()
-        self.origin = origin
-        self.kvlen = kvlen
-
-        self.config = origin.config
-        self.embed_tokens = origin.embed_tokens
-        self.embed_scale = origin.embed_scale
-        self._use_flash_attention_2 = origin._use_flash_attention_2
-        self.embed_positions = origin.embed_positions
-        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
-        self.layernorm_embedding = origin.layernorm_embedding
-        self.layers = origin.layers
-        self.layer_norm = origin.layer_norm
-
-        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        count_pred: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-        run_origin = False
-        if past_key_values is None:
-            run_origin = True
-        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
-            run_origin = True
-
-        if run_origin:
-            return self.origin(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            input = input_ids
-            input_shape = input.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-            input = inputs_embeds[:, :, -1]
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
-            )
-
-        # expand encoder attention mask
-        if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            if self._use_flash_attention_2:
-                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
-            else:
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                encoder_attention_mask = _prepare_4d_attention_mask(
-                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-                )
-
-        # embed positions
-        positions = self.patched_embed_positions(input, self.kvlen)
-
-        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
-
-        # TODO: add counting context weight to hidden_states
-        if count_pred is not None:
-            count_context_weight = self.counting_context_weight(count_pred)
-            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
-        hidden_states = self.layernorm_embedding(hidden_states)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
-        next_decoder_cache = () if use_cache else None
-
-        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
-        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
-            if attn_mask is not None:
-                if attn_mask.size()[0] != len(self.layers):
-                    raise ValueError(
-                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
-                        f" {attn_mask.size()[0]}."
-                    )
-        for idx, decoder_layer in enumerate(self.layers):
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                cross_attn_layer_head_mask=(
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                ),
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
-        hidden_states = self.layer_norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
-                if v is not None
-            )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
-        )
-
-
-class PatchedMBartAttention(nn.Module):
-
-    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
-        super().__init__()
-        self.embed_dim = origin.embed_dim
-        self.num_heads = origin.num_heads
-        self.dropout = origin.dropout
-        self.head_dim = origin.head_dim
-        self.config = origin.config
-
-        self.scaling = origin.scaling
-        self.is_decoder = origin.is_decoder
-        self.is_causal = origin.is_causal
-
-        self.k_proj = origin.k_proj
-        self.v_proj = origin.v_proj
-        self.q_proj = origin.q_proj
-        self.out_proj = origin.out_proj
-        self.kvlen = kvlen
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
-        ):
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-            if past_key_value[0].size(-2) < attention_mask.size(-1):
-                key_states = torch.cat([past_key_value[0], key_states], dim=2)
-                value_states = torch.cat([past_key_value[1], value_states], dim=2)
-            else:
-                past_key_value[0][:, :, self.kvlen[None]] = key_states
-                past_key_value[1][:, :, self.kvlen[None]] = value_states
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            past_key_value = (key_states, value_states)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
-                )
-            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = attn_weights
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned across GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        # attn_output = self.out_proj(attn_output)
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
-
-
-class PatchedMBartSqueezeAttention(nn.Module):
-
-    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
-        super().__init__()
-        self.embed_dim = origin.embed_dim
-        self.num_heads = origin.num_heads
-        self.dropout = origin.dropout
-        self.head_dim = origin.head_dim
-        self.squeeze_head_dim=origin.squeeze_head_dim
-        self.config = origin.config
-
-        self.scaling = origin.scaling
-        self.is_decoder = origin.is_decoder
-        self.scaling = origin.scaling
-
-        self.q_proj = origin.q_proj
-        self.k_proj = origin.k_proj
-        self.v_proj = origin.v_proj
-        self.out_proj = origin.out_proj
-        self.kvlen = kvlen
-
-    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
-
-    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
-        ):
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
-
-            if past_key_value[0].size(-2) < attention_mask.size(-1):
-                key_states = torch.cat([past_key_value[0], key_states], dim=2)
-                value_states = torch.cat([past_key_value[1], value_states], dim=2)
-            else:
-                past_key_value[0][:, :, self.kvlen[None]] = key_states
-                past_key_value[1][:, :, self.kvlen[None]] = value_states
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            # self_attention
-            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
-        value_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*value_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
-                )
-            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned across GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
-
-def patch_model(model: nn.Module, kvlen: torch.LongTensor):
-    for name, child in model.named_children():
-        cls_name = type(child).__name__
-        if cls_name == 'MBartAttention':
-            patched_child = PatchedMBartAttention(child, kvlen)
-            model.register_module(name, patched_child)
-        elif cls_name == 'MBartSqueezeAttention':
-            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
-            model.register_module(name, patched_child)
-        else:
-            patch_model(child, kvlen)
-
-    cls_name = type(model).__name__
-    if cls_name == 'CustomMBartDecoder':
-        model = PatchedMBartDecoder(model, kvlen)
-    return model
-
-
-def next_power_of_2(n: int):
-    """Return the smallest power of 2 greater than or equal to n."""
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n |= n >> 32
-    n += 1
-    return n
-
-
-def get_graph_key(batch_size: int, kvlens: int):
-    batch_size = next_power_of_2(batch_size)
-    kvlens = next_power_of_2(kvlens)
-
-    batch_size = max(8, batch_size)
-    kvlens = max(32, kvlens)
-
-    return batch_size, kvlens
-
-
-class GraphRunnerImpl:
-
-    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
-        self.model = model
-        self.graph = graph
-        self.input_buffers = input_buffers
-        self.output_buffers = output_buffers
-
-    @staticmethod
-    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
-        input_ids = input_buffers['input_ids'][:batch_size]
-        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
-        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
-        kvlen=input_buffers['kvlen']
-
-        past_key_values = []
-        for past_key_value in input_buffers['past_key_values']:
-            k0 = past_key_value[0][:batch_size, :, :kvlens]
-            v0 = past_key_value[1][:batch_size, :, :kvlens]
-            k1 = past_key_value[2][:batch_size]
-            v1 = past_key_value[3][:batch_size]
-            past_key_values.append((k0, v0, k1, v1))
-
-        input_buffers = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-            kvlen=kvlen,
-        )
-        return input_buffers
-
-    @staticmethod
-    def fill_input_buffers(
-        input_buffer: dict,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        ):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-
-        input_buffer['input_ids'][:batch_size] = input_ids
-
-        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
-            input_buffer['attention_mask'].fill_(0)
-        input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
-        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
-
-        if past_key_values is not None:
-            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
-                idx = 0
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
-                idx = 1
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
-
-                idx = 2
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size] = kv[idx]
-                idx = 3
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size] = kv[idx]
-
-        input_buffer['kvlen'].fill_(kvlens - 1)
-
-    @classmethod
-    @torch.inference_mode()
-    def capture(cls,
-                model: nn.Module,
-                input_buffers: dict,
-                pool,
-                warmup: bool = False,
-                input_ids: torch.LongTensor = None,
-                attention_mask: Optional[torch.Tensor] = None,
-                count_pred: Optional[torch.FloatTensor] = None,
-                encoder_hidden_states: Optional[torch.FloatTensor] = None,
-                encoder_attention_mask: Optional[torch.LongTensor] = None,
-                head_mask: Optional[torch.Tensor] = None,
-                cross_attn_head_mask: Optional[torch.Tensor] = None,
-                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                inputs_embeds: Optional[torch.FloatTensor] = None,
-                use_cache: Optional[bool] = None,
-                output_attentions: Optional[bool] = None,
-                output_hidden_states: Optional[bool] = None,
-                return_dict: Optional[bool] = None,):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-
-        graph_key = get_graph_key(batch_size, kvlens)
-        batch_size = graph_key[0]
-        kvlens = graph_key[1]
-
-        input_buffers = cls.extract_input_buffers(input_buffers,
-                                                  batch_size=batch_size,
-                                                  kvlens=kvlens)
-        cls.fill_input_buffers(input_buffers,
-                               input_ids,
-                               attention_mask,
-                               encoder_hidden_states,
-                               past_key_values)
-
-        input_ids = input_buffers['input_ids']
-        attention_mask = input_buffers['attention_mask']
-        encoder_hidden_states = input_buffers['encoder_hidden_states']
-        past_key_values = input_buffers['past_key_values']
-
-        if warmup:
-            # warmup
-            model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict)
-
-        graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph,
-                              pool=pool):
-            outputs = model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict)
-
-        output_buffers = dict(
-            last_hidden_state=outputs['last_hidden_state'],
-            past_key_values=outputs['past_key_values'],
-            )
-
-        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)
-
-    def __call__(self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        count_pred: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        ):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-        self.fill_input_buffers(self.input_buffers,
-                               input_ids,
-                               attention_mask,
-                               encoder_hidden_states,
-                               past_key_values)
-
-        self.graph.replay()
-
-        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
-
-        past_key_values = []
-        for past_key_value in self.output_buffers['past_key_values']:
-            k0 = past_key_value[0][:batch_size, :, :kvlens]
-            v0 = past_key_value[1][:batch_size, :, :kvlens]
-            k1 = past_key_value[2][:batch_size]
-            v1 = past_key_value[3][:batch_size]
-            past_key_values.append((k0, v0, k1, v1))
-
-        outputs = BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=last_hidden_state,
-            past_key_values=past_key_values,
-        )
-        return outputs
-
-class GraphRunner(nn.Module):
-
-    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'):
-        super().__init__()
-
-        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
-        model = patch_model(model.to(dtype), self.kvlen)
-        self.model = model
-        self.max_batchs = max_batchs
-        self.max_kvlens = max_kvlens
-        self.device = device
-
-        self.input_buffers = None
-
-        self.impl_map = dict()
-        self.graph_pool_handle = torch.cuda.graph_pool_handle()
-        self.warmuped = False
-
-    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
-        max_batchs = self.max_batchs
-        max_kvlens = self.max_kvlens
-        device = self.device
-        config = self.model.config
-
-        d_model = config.d_model
-        decoder_layers = config.decoder_layers
-        num_heads = config.decoder_attention_heads
-
-        head_dim = d_model // num_heads
-        self_attn = self.model.layers[0].self_attn
-        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)
-
-        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
-        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
-        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
-
-        past_key_values = []
-        for _ in range(decoder_layers):
-            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
-            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
-            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
-            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
-
-            past_key_values.append((k0, v0, k1, v1))
-
-        self.input_buffers = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-            kvlen=self.kvlen
-        )
-
-    @torch.inference_mode()
-    def forward(self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        count_pred: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        ):
-        batch_size, qlens = input_ids.size()
-        kvlens = attention_mask.size(1)
-
-        eager_mode = False
-
-        if qlens != 1:
-            eager_mode = True
-
-        if past_key_values is None:
-            eager_mode = True
-        else:
-            for past_key_value in past_key_values:
-                if past_key_value is None:
-                    eager_mode = True
-                    break
-
-        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
-            eager_mode = True
-
-        if eager_mode:
-            return self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,)
-
-        # create buffer if not exists.
-        if self.input_buffers is None:
-            encoder_kvlens = encoder_hidden_states.size(1)
-            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
-
-        graph_key = get_graph_key(batch_size, kvlens)
-        if graph_key not in self.impl_map:
-            warmup = False
-            if not self.warmuped:
-                warmup = True
-                self.warmuped = True
-            impl = GraphRunnerImpl.capture(
-                self.model,
-                self.input_buffers,
-                self.graph_pool_handle,
-                warmup=warmup,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            self.impl_map[graph_key] = impl
-        impl = self.impl_map[graph_key]
-
-        ret = impl(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-        )
-        return ret

+ 0 - 6
magic_pdf/model/pdf_extract_kit.py

@@ -6,7 +6,6 @@ import shutil
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
-from .mfr_cudagraph import GraphRunner
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
@@ -70,11 +69,6 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model.to(_device_)
     model.eval()
-    model = model.to(_device_)
-    if 'cuda' in _device_:
-        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
-                                     device=_device_)
-        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]