@@ -1,449 +0,0 @@
-import math
-import re
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from transformers import (
-    Qwen2ForCausalLM,
-    Qwen2Model,
-    SiglipVisionConfig,
-    SiglipVisionModel,
-)
-from transformers.generation.utils import GenerateOutput
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from .configuration_mineru2 import Mineru2QwenConfig
-from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
-
-
-class SiglipVisionTower(nn.Module):
-    def __init__(self, vision_tower):
-        super().__init__()
-
-        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
-        assert isinstance(self.config, SiglipVisionConfig)
-        self.config.num_hidden_layers -= 1  # drop the last hidden layer
-        self.config.vision_use_head = False
-
-        self.vision_tower = SiglipVisionModel(self.config)
-        self.vision_tower.requires_grad_(False)
-
-        self.image_processor = Mineru2ImageProcessor()
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_forward_out = self.vision_tower(
-                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
-                )
-                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
-                image_features.append(image_feature)
-        else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        for p in self.vision_tower.parameters():
-            return p.dtype
-
-    @property
-    def device(self):
-        for p in self.vision_tower.parameters():
-            return p.device
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size
-
-    @property
-    def num_patches(self):
-        return (self.config.image_size // self.config.patch_size) ** 2
-
-    @property
-    def num_patches_per_side(self):
-        return self.config.image_size // self.config.patch_size
-
-    @property
-    def image_size(self):
-        return self.config.image_size
-
-
-def build_vision_tower(config: Mineru2QwenConfig):
-    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
-    model_path = getattr(config, "_name_or_path", "")
-    if "siglip" in vision_tower.lower():
-        if model_path:
-            return SiglipVisionTower(f"{model_path}/{vision_tower}")
-        else:
-            return SiglipVisionTower(vision_tower)
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
-
-
-def build_vision_projector(config: Mineru2QwenConfig):
-    projector_type = getattr(config, "mm_projector_type", "linear")
-
-    if projector_type == "linear":
-        return nn.Linear(config.mm_hidden_size, config.hidden_size)
-
-    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
-    if mlp_gelu_match:
-        mlp_depth = int(mlp_gelu_match.group(1))
-        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
-        for _ in range(1, mlp_depth):
-            modules.append(nn.GELU())  # type: ignore
-            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
-        return nn.Sequential(*modules)
-
-    if projector_type == "identity":
-        return nn.Identity()
-
-    raise ValueError(f"Unknown projector type: {projector_type}")
-
-
-class Mineru2QwenModel(Qwen2Model):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Mineru2QwenModel, self).__init__(config)
-
-        self.vision_tower = build_vision_tower(config)
-        self.mm_projector = build_vision_projector(config)
-
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
-
-
-class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Qwen2ForCausalLM, self).__init__(config)
-        config.rope_scaling = None
-        self.model = Mineru2QwenModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        self.ignore_index = config.ignore_index
-        self.image_token_index = config.image_token_index
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_model(self):
-        return self.model
-
-    def encode_images(self, images: torch.Tensor):
-        image_features = self.get_model().vision_tower(images)
-        image_features = self.get_model().mm_projector(image_features)
-        return image_features
-
-    def prepare_inputs_labels_for_multimodal(
-        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
-    ):
-        vision_tower = self.get_model().vision_tower
-        if vision_tower is None or images is None or input_ids.shape[1] == 1:
-            return input_ids, position_ids, attention_mask, past_key_values, None, labels
-
-        if type(images) is list or images.ndim == 5:
-            if type(images) is list:
-                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
-            concat_images = torch.cat([image for image in images], dim=0)
-            image_features = self.encode_images(concat_images)
-            split_sizes = [image.shape[0] for image in images]
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-            if mm_patch_merge_type == "flat":
-                image_features = [x.flatten(0, 1) for x in image_features]
-            elif mm_patch_merge_type.startswith("spatial"):
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        height = width = self.get_model().vision_tower.num_patches_per_side
-                        assert height * width == base_image_feature.shape[0]
-
-                        if "anyres_max" in image_aspect_ratio:
-                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
-                            if matched_anyres_max_num_patches:
-                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                                image_sizes[image_idx],
-                                self.config.image_grid_pinpoints,
-                                self.get_model().vision_tower.config.image_size,
-                            )
-                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                        else:
-                            raise NotImplementedError
-                        if (
-                            "unpad" in mm_patch_merge_type
-                            and "anyres_max" in image_aspect_ratio
-                            and matched_anyres_max_num_patches
-                        ):
-                            unit = image_feature.shape[2]
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            c, h, w = image_feature.shape
-                            times = math.sqrt(h * w / (max_num_patches * unit**2))
-                            if times > 1.1:
-                                image_feature = image_feature[None]
-                                image_feature = nn.functional.interpolate(
-                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
-                                )[0]
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        elif "unpad" in mm_patch_merge_type:
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        else:
-                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
-                            image_feature = image_feature.flatten(0, 3)
-                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                    else:
-                        image_feature = image_feature[0]
-                        if "unpad" in mm_patch_merge_type:
-                            image_feature = torch.cat(
-                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
-                            )
-                    new_image_features.append(image_feature)
-                image_features = new_image_features
-            else:
-                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
-        else:
-            image_features = self.encode_images(images)
-
-        _labels = labels
-        _position_ids = position_ids
-        _attention_mask = attention_mask
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-        else:
-            attention_mask = attention_mask.bool()
-        if position_ids is None:
-            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-        if labels is None:
-            labels = torch.full_like(input_ids, self.ignore_index)
-
-        # remove the padding using attention_mask -- FIXME
-        _input_ids = input_ids
-        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
-        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
-
-        new_input_embeds = []
-        new_labels = []
-        cur_image_idx = 0
-        for batch_idx, cur_input_ids in enumerate(input_ids):
-            num_images = (cur_input_ids == self.image_token_index).sum()
-            if num_images == 0:
-                cur_image_features = image_features[cur_image_idx]
-                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
-                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
-                new_input_embeds.append(cur_input_embeds)
-                new_labels.append(labels[batch_idx])
-                cur_image_idx += 1
-                continue
-
-            image_token_indices = (
-                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
-            )
-            cur_input_ids_noim = []
-            cur_labels = labels[batch_idx]
-            cur_labels_noim = []
-            for i in range(len(image_token_indices) - 1):
-                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-            split_sizes = [x.shape[0] for x in cur_labels_noim]
-            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
-            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
-            cur_new_input_embeds = []
-            cur_new_labels = []
-
-            for i in range(num_images + 1):
-                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
-                cur_new_labels.append(cur_labels_noim[i])
-                if i < num_images:
-                    cur_image_features = image_features[cur_image_idx]
-                    cur_image_idx += 1
-                    cur_new_input_embeds.append(cur_image_features)
-                    cur_new_labels.append(
-                        torch.full(
-                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
-                        )
-                    )
-
-            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
-
-            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
-            cur_new_labels = torch.cat(cur_new_labels)
-
-            new_input_embeds.append(cur_new_input_embeds)
-            new_labels.append(cur_new_labels)
-
-        # Truncate sequences to max length as image embeddings can make the sequence longer
-        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
-        if tokenizer_model_max_length is not None:
-            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
-            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
-
-        # Combine them
-        max_len = max(x.shape[0] for x in new_input_embeds)
-        batch_size = len(new_input_embeds)
-
-        new_input_embeds_padded = []
-        new_labels_padded = torch.full(
-            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
-        )
-        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
-        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
-
-        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
-            cur_len = cur_new_embed.shape[0]
-            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                            cur_new_embed,
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, -cur_len:] = cur_new_labels
-                    attention_mask[i, -cur_len:] = True
-                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-            else:
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            cur_new_embed,
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, :cur_len] = cur_new_labels
-                    attention_mask[i, :cur_len] = True
-                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-
-        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
-
-        if _labels is None:
-            new_labels = None
-        else:
-            new_labels = new_labels_padded
-
-        if _attention_mask is None:
-            attention_mask = None
-        else:
-            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
-
-        if _position_ids is None:
-            position_ids = None
-
-        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        if inputs_embeds is None:
-            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
-                self.prepare_inputs_labels_for_multimodal(
-                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
-                )
-            )
-        return super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        images: Optional[torch.Tensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
-        position_ids = kwargs.pop("position_ids", None)
-        attention_mask = kwargs.pop("attention_mask", None)
-        if "inputs_embeds" in kwargs:
-            raise NotImplementedError("`inputs_embeds` is not supported")
-
-        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
-            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
-        )
-
-        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        images = kwargs.pop("images", None)
-        image_sizes = kwargs.pop("image_sizes", None)
-        inputs = super().prepare_inputs_for_generation(
-            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
-        )
-        if images is not None:
-            inputs["images"] = images
-        if image_sizes is not None:
-            inputs["image_sizes"] = image_sizes
-        return inputs