import math
import re
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    Qwen2ForCausalLM,
    Qwen2Model,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape

class SiglipVisionTower(nn.Module):
    def __init__(self, vision_tower):
        super().__init__()
        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
        assert isinstance(self.config, SiglipVisionConfig)
        self.config.num_hidden_layers -= 1  # drop the last hidden layer
        self.config.vision_use_head = False
        self.vision_tower = SiglipVisionModel(self.config)
        self.vision_tower.requires_grad_(False)
        self.image_processor = Mineru2ImageProcessor()

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size
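
# Note: the tower above keeps per-patch hidden states only -- the SigLIP pooling
# head is disabled and the final encoder layer is dropped -- so each image is
# encoded as a sequence of patch tokens ([batch, num_patches, hidden_size]).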

def build_vision_tower(config: Mineru2QwenConfig):
    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
    model_path = getattr(config, "_name_or_path", "")
    if "siglip" in vision_tower.lower():
        if model_path:
            return SiglipVisionTower(f"{model_path}/{vision_tower}")
        else:
            return SiglipVisionTower(vision_tower)
    raise ValueError(f"Unknown vision tower: {vision_tower}")

def build_vision_projector(config: Mineru2QwenConfig):
    projector_type = getattr(config, "mm_projector_type", "linear")
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())  # type: ignore
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return nn.Identity()
    raise ValueError(f"Unknown projector type: {projector_type}")

class Mineru2QwenModel(Qwen2Model):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Mineru2QwenModel, self).__init__(config)
        self.vision_tower = build_vision_tower(config)
        self.mm_projector = build_vision_projector(config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Qwen2ForCausalLM, self).__init__(config)
        config.rope_scaling = None
        self.model = Mineru2QwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.ignore_index = config.ignore_index
        self.image_token_index = config.image_token_index
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def encode_images(self, images: torch.Tensor):
        image_features = self.get_model().vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features
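
    # encode_images maps pixel tensors to LLM-space embeddings: the vision tower
    # yields [batch, num_patches, mm_hidden_size] patch features, and the projector
    # maps the last dimension to config.hidden_size so patch tokens can be spliced
    # directly into the text embedding sequence.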

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_model().vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_model().vision_tower.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                self.get_model().vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if (
                            "unpad" in mm_patch_merge_type
                            and "anyres_max" in image_aspect_ratio
                            and matched_anyres_max_num_patches
                        ):
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            c, h, w = image_feature.shape
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(
                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
                                )[0]
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)
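
        # At this point image_features is either a list with one tensor per sample
        # (each [num_image_tokens, hidden_size]) for anyres/spatial inputs, or a
        # single batched tensor for plain square images.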

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, self.ignore_index)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.image_token_index).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = (
                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
                        )
                    )
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
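
    # forward() defers to prepare_inputs_labels_for_multimodal to splice the
    # projected image features into `inputs_embeds`, then runs the standard
    # Qwen2 causal-LM forward on the resulting embedding sequence.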

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
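
    # generate() precomputes the multimodal `inputs_embeds` once up front and hands
    # them to the underlying Qwen2 generation loop; the returned `input_ids` from
    # prepare_inputs_labels_for_multimodal are intentionally not forwarded.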

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
        )

        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
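
# A minimal, hypothetical usage sketch (paths, tokenization, and image preprocessing
# are illustrative assumptions, not the repo's documented API):
#
#   from transformers import AutoTokenizer
#
#   model = Mineru2QwenForCausalLM.from_pretrained("path/to/checkpoint")
#   tok = AutoTokenizer.from_pretrained("path/to/checkpoint")
#   input_ids = tok(prompt_with_image_placeholder, return_tensors="pt").input_ids
#   pixel_values = model.get_model().vision_tower.image_processor(
#       images=[pil_image], return_tensors="pt"
#   )["pixel_values"]
#   output_ids = model.generate(input_ids, images=pixel_values, image_sizes=[pil_image.size])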