| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from functools import partial
- from typing import List, Optional, Tuple, Type
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from ...common.vlm.transformers.model_outputs import CausalLMOutputWithPast
- from .qwen2 import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
class MLPBlock(paddle.nn.Layer):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    Args:
        embedding_dim (int): Size of the input and output features.
        mlp_dim (int): Size of the hidden layer.
        act (Type[paddle.nn.Layer]): Activation layer class; it is instantiated
            once with no arguments. Defaults to ``paddle.nn.GELU``.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[paddle.nn.Layer] = paddle.nn.GELU,
    ) -> None:
        super().__init__()
        # Attribute names are part of the parameter names in the state dict; keep them.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Apply ``lin1``, the activation, then ``lin2`` to ``x``."""
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
class LayerNorm2d(paddle.nn.Layer):
    """Layer normalization across axis 1 (channels), with a per-channel affine.

    The per-channel ``weight``/``bias`` are indexed as ``[:, None, None]``, so the
    input is presumably [B, C, H, W] (it follows a Conv2D in this file).

    Args:
        num_channels (int): Number of channels; length of ``weight`` and ``bias``.
        epsilon (float): Value added to the variance for numerical stability.
    """

    def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None:
        super().__init__()
        # Start from the identity transform: scale 1, shift 0.
        self.weight = paddle.base.framework.EagerParamBase.from_tensor(
            tensor=paddle.ones(shape=num_channels)
        )
        self.bias = paddle.base.framework.EagerParamBase.from_tensor(
            tensor=paddle.zeros(shape=num_channels)
        )
        self.epsilon = epsilon

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Normalize ``x`` over axis 1, then scale and shift per channel."""
        centered = x - x.mean(axis=1, keepdim=True)
        variance = centered.pow(y=2).mean(axis=1, keepdim=True)
        normalized = centered / paddle.sqrt(x=variance + self.epsilon)
        # [C] -> [C, 1, 1] so the affine parameters broadcast over the trailing spatial axes.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
class ImageEncoderViT(paddle.nn.Layer):
    """Vision Transformer image encoder followed by a convolutional neck.

    Pipeline: patch embedding -> optional absolute position embedding ->
    ``depth`` transformer blocks -> conv neck (1x1 then 3x3) to ``out_chans``
    channels -> two stride-2 3x3 convolutions (``net_2``: out_chans -> 512,
    ``net_3``: 512 -> 1024).
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Layer] = nn.LayerNorm,
        act_layer: Type[nn.Layer] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of output channels of the neck, and therefore
                the number of input channels of ``net_2``.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Layer): Normalization layer.
            act_layer (nn.Layer): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Forwarded to each Block.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        # Side length of the token grid produced by the patch embedding.
        grid_size = img_size // patch_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.pos_embed: Optional[paddle.Tensor] = None
        if use_abs_pos:
            # Learnable [1, grid, grid, embed_dim] table, zero-initialized, added to the tokens.
            self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(shape=[1, grid_size, grid_size, embed_dim])
            )
        self.blocks = paddle.nn.LayerList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes get window_size=0 (global attention).
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(grid_size, grid_size),
            )
            self.blocks.append(block)
        self.neck = nn.Sequential(
            nn.Conv2D(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias_attr=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2D(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias_attr=False,
            ),
            LayerNorm2d(out_chans),
        )
        # net_2 consumes the neck output, so its input channels must equal out_chans
        # (a hard-coded 256 only worked for the default out_chans).
        self.net_2 = nn.Conv2D(
            out_chans, 512, kernel_size=3, stride=2, padding=1, bias_attr=False
        )
        self.net_3 = nn.Conv2D(
            512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Encode images into a downsampled feature map.

        Args:
            x (paddle.Tensor): Images in NCHW layout (fed to PatchEmbed's Conv2D).

        Returns:
            paddle.Tensor: NCHW feature map with 1024 channels; each of ``net_2`` and
            ``net_3`` is a stride-2 convolution that roughly halves H and W.
        """
        # Tokens come out of PatchEmbed as [B, H, W, C].
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        # Back to NCHW for the convolutional neck.
        x = self.neck(x.transpose([0, 3, 1, 2]))
        x = self.net_2(x)
        x = self.net_3(x)
        return x
class Block(paddle.nn.Layer):
    """Pre-norm transformer block: attention and MLP, each with a residual add.

    When ``window_size`` > 0 the attention runs independently inside
    non-overlapping ``window_size`` x ``window_size`` windows; otherwise it is
    global over the whole token grid.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Layer] = nn.LayerNorm,
        act_layer: Type[nn.Layer] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Hidden size of the MLP as a multiple of ``dim``.
            qkv_bias (bool): Whether the query/key/value projection has a bias.
            norm_layer (nn.Layer): Normalization layer class, built as ``norm_layer(dim)``.
            act_layer (nn.Layer): Activation layer class used by the MLP.
            use_rel_pos (bool): Whether attention adds relative positional embeddings.
            rel_pos_zero_init (bool): Forwarded to the attention layer.
            window_size (int): Attention window size; 0 means global attention.
            input_size (tuple(int, int) or None): Token-grid size used for the relative
                positional tables when attention is global. Windowed blocks use
                (window_size, window_size) instead.
        """
        super().__init__()
        if window_size == 0:
            attn_input_size = input_size
        else:
            attn_input_size = (window_size, window_size)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(
            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
        )
        self.window_size = window_size

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Run the block on tokens of shape [B, H, W, C] and return the same shape."""
        residual = x
        hidden = self.norm1(x)
        windowed = self.window_size > 0
        if windowed:
            height, width = hidden.shape[1], hidden.shape[2]
            hidden, padded_hw = window_partition(hidden, self.window_size)
        hidden = self.attn(hidden)
        if windowed:
            # Stitch the windows back together and drop any padding.
            hidden = window_unpartition(
                hidden, self.window_size, padded_hw, (height, width)
            )
        x = residual + hidden
        return x + self.mlp(self.norm2(x))
class Attention(paddle.nn.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Accepted for interface compatibility but not read
                here; the relative positional tables are always created as zeros.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size. Required when ``use_rel_pos`` is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One embedding per relative offset in [-(size-1), size-1] along each axis.
            self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim])
            )
            self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim])
            )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Self-attention over the H*W positions of ``x``.

        Args:
            x (paddle.Tensor): Tokens of shape [B, H, W, C].

        Returns:
            paddle.Tensor: Tensor of shape [B, H, W, dim] (output of ``proj``).
        """
        B, H, W, _ = tuple(x.shape)
        # [B, H, W, 3*C] -> [B, H*W, 3, heads, head_dim] -> [3, B, heads, H*W, head_dim]
        qkv = (
            self.qkv(x)
            .reshape([B, H * W, 3, self.num_heads, -1])
            .transpose([2, 0, 3, 1, 4])
        )
        # Fold batch and heads: q, k, v are each [B*heads, H*W, head_dim].
        q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0)
        attn = (q * self.scale) @ k.transpose([0, 2, 1])
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
        attn = F.softmax(attn, axis=-1)
        # [B*heads, H*W, head_dim] -> [B, H, W, heads*head_dim]
        x = (
            (attn @ v)
            .reshape([B, self.num_heads, H, W, -1])
            .transpose([0, 2, 3, 1, 4])
            .reshape([B, H, W, -1])
        )
        x = self.proj(x)
        return x
def window_partition(
    x: paddle.Tensor, window_size: int
) -> Tuple[paddle.Tensor, Tuple[int, int]]:
    """
    Split a [B, H, W, C] token grid into non-overlapping square windows.

    H and W are zero-padded on the bottom/right up to a multiple of
    ``window_size`` first.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: tensor of shape [B * num_windows, window_size, window_size, C].
        (Hp, Wp): height and width after padding.
    """
    batch, height, width, channels = tuple(x.shape)
    # Amount needed to reach the next multiple of window_size (0 if already aligned).
    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    if max(pad_bottom, pad_right) > 0:
        x = F.pad(x, pad=(0, pad_right, 0, pad_bottom), data_format="NHWC")
    padded_h = height + pad_bottom
    padded_w = width + pad_right
    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.reshape([batch, rows, window_size, cols, window_size, channels])
    # Bring the two window-index axes together before flattening them into the batch.
    grid = grid.transpose([0, 1, 3, 2, 4, 5])
    windows = grid.reshape([-1, window_size, window_size, channels])
    return windows, (padded_h, padded_w)
def window_unpartition(
    windows: paddle.Tensor,
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> paddle.Tensor:
    """
    Reassemble windows produced by ``window_partition`` and crop the padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: tensor of shape [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = tuple(windows.shape)[0] // windows_per_image
    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = windows.reshape([batch, rows, cols, window_size, window_size, -1])
    # Interleave window rows with in-window rows (and likewise for columns).
    grid = grid.transpose([0, 1, 3, 2, 4, 5])
    x = grid.reshape([batch, padded_h, padded_w, -1])
    needs_crop = padded_h > height or padded_w > width
    if needs_crop:
        x = x[:, :height, :width, :]
    return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    If the table length differs from ``2 * max(q_size, k_size) - 1``, it is first
    resized along its length with linear interpolation.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
    Returns:
        Tensor of shape (q_size, k_size, C): the embedding for each (query, key)
        pair, selected by their scaled relative offset.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    num_pos = tuple(rel_pos.shape)[0]
    if num_pos != max_rel_dist:
        # (L, C) -> (1, C, L), i.e. NCW with the table length as W. The layout is
        # stated explicitly because Paddle versions whose default data_format is
        # "NCHW" reject 3-D input.
        rel_pos_resized = paddle.nn.functional.interpolate(
            rel_pos.reshape([1, num_pos, -1]).transpose([0, 2, 1]),
            size=max_rel_dist,
            mode="linear",
            data_format="NCW",
        )
        # (1, C, max_rel_dist) -> (max_rel_dist, C)
        rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0])
    else:
        rel_pos_resized = rel_pos
    # Scale coordinates so query and key grids of different sizes share one offset range.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = paddle.arange(end=q_size)[:, None] * q_scale
    k_coords = paddle.arange(end=k_size)[None, :] * k_scale
    # Shift by (k_size - 1) * k_scale so the most negative offset maps to index 0.
    relative_coords = q_coords - k_coords + (k_size - 1) * k_scale
    return rel_pos_resized[relative_coords.astype(dtype="int64")]
def add_decomposed_rel_pos(
    attn: paddle.Tensor,
    q: paddle.Tensor,
    rel_pos_h: paddle.Tensor,
    rel_pos_w: paddle.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> paddle.Tensor:
    """
    Add decomposed (height + width) relative positional terms to an attention map,
    following :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention logits, reshapeable to (B, q_h, q_w, k_h, k_w).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    Returns:
        attn (Tensor): attention logits of shape (B, q_h * q_w, k_h * k_w) with the
        positional terms added.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)
    batch, _, channels = tuple(q.shape)
    q_grid = q.reshape([batch, q_h, q_w, channels])
    # Dot each query with the embedding of its offset to every key row / key column.
    height_bias = paddle.einsum("bhwc,hkc->bhwk", q_grid, height_table)
    width_bias = paddle.einsum("bhwc,wkc->bhwk", q_grid, width_table)
    attn_grid = attn.reshape([batch, q_h, q_w, k_h, k_w])
    # Height terms broadcast over key columns, width terms over key rows.
    attn_grid = attn_grid + height_bias.unsqueeze(-1)
    attn_grid = attn_grid + width_bias.unsqueeze(-2)
    return attn_grid.reshape([batch, q_h * q_w, k_h * k_w])
class PatchEmbed(paddle.nn.Layer):
    """
    Image to Patch Embedding: a single Conv2D whose output is returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): Kernel size of the projection convolution.
            stride (Tuple): Stride of the projection convolution.
            padding (Tuple): Padding of the projection convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Number of output channels (the embedding size).
        """
        super().__init__()
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        """Project NCHW images and return the result as [B, H, W, C]."""
        projected = self.proj(x)
        channels_last = [0, 2, 3, 1]
        return projected.transpose(channels_last)
# String forms of the image-related special tokens.
# NOTE: GOTQwenModel.forward locates image spans by hard-coded integer token ids,
# not by these strings.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
# Image start/end markers.
DEFAULT_IM_START_TOKEN = "<img>"
DEFAULT_IM_END_TOKEN = "</img>"
class Qwen2LMHead(nn.Layer):
    """Vocabulary projection: logits = hidden_states @ weight (optionally transposed).

    Weight layout:
      * ``transpose_y=False``: a fresh [hidden_size, vocab_size] parameter.
      * ``transpose_y=True`` without ``embedding_weights``: a fresh
        [vocab_size, hidden_size] parameter.
      * ``transpose_y=True`` with ``embedding_weights``: that tensor is used as-is
        (shared with the caller).

    ``tensor_parallel_output`` is accepted in both ``__init__`` and ``forward`` but is
    not used.
    """

    def __init__(
        self,
        config,
        embedding_weights=None,
        transpose_y=False,
        tensor_parallel_output=1,
    ):
        super().__init__()
        self.config = config
        self.transpose_y = transpose_y
        vocab_size = config.vocab_size
        if not transpose_y:
            # for weight from model init
            self.weight = self.create_parameter(
                shape=[config.hidden_size, vocab_size],
                dtype=paddle.get_default_dtype(),
            )
        elif embedding_weights is None:
            self.weight = self.create_parameter(
                shape=[vocab_size, config.hidden_size],
                dtype=paddle.get_default_dtype(),
            )
        else:
            # Only the transposed layout can reuse an embedding matrix directly.
            self.weight = embedding_weights

    def forward(self, hidden_states, tensor_parallel_output=1):
        """Return ``hidden_states @ weight`` (``weight`` transposed if ``transpose_y``)."""
        return paddle.matmul(
            hidden_states, self.weight, transpose_y=self.transpose_y
        )
class GOTConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``"GOT"`` model type."""

    model_type = "GOT"
class GOTQwenModel(Qwen2Model):
    """Qwen2 decoder whose input embeddings can be overwritten with image features.

    Each image is encoded by ``vision_tower_high`` (an ImageEncoderViT, run without
    gradients) and projected by ``mm_projector_vary`` (Linear 1024 -> 1024). The
    projected features then replace the embeddings of the positions that follow
    each image-start token in ``input_ids``.
    """

    config_class = GOTConfig

    # Special-token ids used to locate image spans in ``input_ids``. They are
    # hard-coded; the config's im_patch_token / im_start_token / im_end_token
    # entries are not consulted.
    _IM_PATCH_TOKEN = 151859
    _IM_START_TOKEN = 151857
    _IM_END_TOKEN = 151858

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)
        self.vision_tower_high = ImageEncoderViT(
            depth=12,
            embed_dim=768,
            img_size=1024,
            mlp_ratio=4,
            norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6),
            num_heads=12,
            patch_size=16,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=[2, 5, 8, 11],
            window_size=14,
            out_chans=256,
        )
        self.mm_projector_vary = nn.Linear(1024, 1024)

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[paddle.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Embed tokens, splice in image features when present, then run Qwen2Model.

        The image path runs only if ``images`` is given and either the model is
        training or ``input_ids`` has more than one position (so single-token
        decoding steps skip it). ``input_ids`` must be provided on that path.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, "orig_embeds_params", None)
        if orig_embeds_params is not None:
            with paddle.no_grad():
                self.get_input_embeddings().weight[: -self.num_new_tokens] = (
                    orig_embeds_params[: -self.num_new_tokens].data
                )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        vision_tower_high = getattr(self, "vision_tower_high", None)
        if (
            vision_tower_high is not None
            and (input_ids.shape[1] != 1 or self.training)
            and images is not None
        ):
            image_features = self._encode_images(vision_tower_high, images)
            inputs_embeds = self._merge_image_features(
                input_ids, inputs_embeds, image_features
            )
        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def _project_image(self, vision_tower_high, image):
        """Encode a [1, C, H, W] batch without gradients, then project it (with gradients)."""
        with paddle.set_grad_enabled(False):
            cnn_feature = vision_tower_high(image)
            # [N, C, h, w] -> [N, h*w, C]; e.g. [1, 256, 1024] for a 1024x1024 image.
            cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])
        return self.mm_projector_vary(cnn_feature)

    def _encode_images(self, vision_tower_high, images):
        """Return one projected feature tensor per item of ``images``.

        Items with several views along axis 0 are encoded view by view, and the
        per-view features are concatenated along the token axis.
        """
        image_features = []
        for image in images:
            if self.training:
                # NOTE(review): training items are indexable; element 1 is the one
                # encoded here — confirm the expected item layout with the data pipeline.
                image = image[1]
            num_views, _, _, _ = image.shape
            if num_views == 1:
                image_features.append(self._project_image(vision_tower_high, image))
            else:
                view_features = [
                    self._project_image(vision_tower_high, paddle.stack([view]))
                    for view in paddle.unbind(image)
                ]
                image_features.append(paddle.concat(view_features, axis=1))
        return image_features

    def _merge_image_features(self, input_ids, inputs_embeds, image_features):
        """Overwrite the embeddings after each image-start token with image features.

        Raises:
            ValueError: if a sample has unequal numbers of start and end tokens, or
                if an image span is not immediately followed by an end token.
        """
        im_patch_token = self._IM_PATCH_TOKEN
        im_start_token = self._IM_START_TOKEN
        im_end_token = self._IM_END_TOKEN
        dummy_image_features = paddle.zeros([256, 1024], dtype=inputs_embeds.dtype)
        new_input_embeds = []
        for cur_input_ids, cur_input_embeds, cur_image_features in zip(
            input_ids, inputs_embeds, image_features
        ):
            if (cur_input_ids == im_patch_token).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal.
                # Adds a scalar 0.0 derived from an all-zero tensor: values are unchanged.
                cur_input_embeds = (
                    cur_input_embeds + (0.0 * dummy_image_features).sum()
                )
                new_input_embeds.append(cur_input_embeds)
                continue
            if (cur_input_ids == im_start_token).sum() != (
                cur_input_ids == im_end_token
            ).sum():
                raise ValueError(
                    "The number of image start tokens and image end tokens should be the same."
                )
            image_start_tokens = paddle.where(cur_input_ids == im_start_token)[0]
            for image_start_token_pos, per_cur_image_features in zip(
                image_start_tokens, cur_image_features
            ):
                num_patches = per_cur_image_features.shape[0]
                if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                    raise ValueError(
                        "The image end token should follow the image start token."
                    )
                # Replace exactly num_patches positions, so the sequence length (and
                # therefore the later start-token positions) stays unchanged.
                cur_input_embeds = paddle.concat(
                    (
                        cur_input_embeds[: image_start_token_pos + 1],
                        per_cur_image_features,
                        cur_input_embeds[image_start_token_pos + num_patches + 1 :],
                    ),
                    axis=0,
                )
            new_input_embeds.append(cur_input_embeds)
        return paddle.stack(new_input_embeds, axis=0)
class GOTQwenForCausalLM(Qwen2ForCausalLM):
    """Causal language-model head on top of GOTQwenModel."""

    config_class = GOTConfig

    def __init__(self, config):
        # Calls the initializer of Qwen2ForCausalLM's parent, skipping
        # Qwen2ForCausalLM.__init__ itself; presumably so that only GOTQwenModel is
        # built as ``self.qwen2`` — confirm against Qwen2ForCausalLM.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.qwen2 = GOTQwenModel(config)
        self.vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection (matmul with transpose_y).
            self.lm_head = Qwen2LMHead(
                config,
                embedding_weights=self.qwen2.embed_tokens.weight,
                transpose_y=True,
            )
            self.tie_weights()
        else:
            self.lm_head = Qwen2LMHead(config)
        self.eval()

    def get_model(self):
        """Return the underlying GOTQwenModel."""
        return self.qwen2

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        labels: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[paddle.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone, compute float32 logits and, if ``labels`` is given, the loss.

        The loss is next-token cross entropy averaged over label positions that are
        not -100. If every label is -100 the loss is 0 rather than NaN.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy; otherwise a
            tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.qwen2(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.astype(dtype="float32")
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].reshape([-1, self.config.vocab_size])
            shift_labels = labels[..., 1:].reshape([-1])
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
            loss = loss_fct(shift_logits, shift_labels)
            num_valid = (shift_labels != -100).astype(loss.dtype).sum()
            # Clamp the denominator so an all-ignored batch gives 0 instead of 0/0 = NaN.
            loss = loss / paddle.clip(num_valid, min=1.0)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        NOTE(review): the incoming ``attention_mask`` is discarded and replaced by an
        all-ones boolean mask over the full ``input_ids``, so padded batches are not
        masked here.
        """
        batch_size, seq_length = input_ids.shape
        attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool)
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            # Axis 1 of the first cached tensor is read as the cached length
            # (assumes a [batch, seq, ...] cache layout — confirm).
            past_length = past_key_values[0][0].shape[1]
            if past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.astype(dtype="int64").cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs
class PPChart2TableInference(GOTQwenForCausalLM):
    """Greedy-decoding inference wrapper around GOTQwenForCausalLM."""

    def generate(self, inputs, **kwargs):
        """Generate token ids for ``inputs`` without gradients.

        Args:
            inputs (dict): Must contain ``"input_ids"`` and ``"images"``.
            **kwargs: Only ``max_new_tokens`` (default 1024) and
                ``no_repeat_ngram_size`` (default 20) are read; other keys are ignored.

        Returns:
            Whatever the parent class's ``generate`` returns.
        """
        decoding_options = {
            "do_sample": False,
            "num_beams": 1,
            "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 20),
            "max_new_tokens": kwargs.get("max_new_tokens", 1024),
        }
        with paddle.no_grad():
            return super().generate(
                inputs["input_ids"],
                images=inputs["images"],
                **decoding_options,
            )
|