import math
import re
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    Qwen2ForCausalLM,
    Qwen2Model,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape


class SiglipVisionTower(nn.Module):
    def __init__(self, vision_tower):
        super().__init__()
        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
        assert isinstance(self.config, SiglipVisionConfig)
        self.config.num_hidden_layers -= 1  # drop the last hidden layer
        self.config.vision_use_head = False
        self.vision_tower = SiglipVisionModel(self.config)
        self.vision_tower.requires_grad_(False)
        self.image_processor = Mineru2ImageProcessor()

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
            )
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size


def build_vision_tower(config: Mineru2QwenConfig):
    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
    model_path = getattr(config, "_name_or_path", "")
    if "siglip" in vision_tower.lower():
        if model_path:
            return SiglipVisionTower(f"{model_path}/{vision_tower}")
        else:
            return SiglipVisionTower(vision_tower)
    raise ValueError(f"Unknown vision tower: {vision_tower}")


def build_vision_projector(config: Mineru2QwenConfig):
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())  # type: ignore
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return nn.Identity()

    raise ValueError(f"Unknown projector type: {projector_type}")


class Mineru2QwenModel(Qwen2Model):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Mineru2QwenModel, self).__init__(config)
        self.vision_tower = build_vision_tower(config)
        self.mm_projector = build_vision_projector(config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
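# Note: for a config with mm_projector_type == "mlp2x_gelu", build_vision_projector
# above returns the equivalent of
#
#     nn.Sequential(
#         nn.Linear(config.mm_hidden_size, config.hidden_size),
#         nn.GELU(),
#         nn.Linear(config.hidden_size, config.hidden_size),
#     )
#
# i.e. a two-layer MLP that maps SigLIP patch features (mm_hidden_size) into the
# Qwen2 embedding space (hidden_size).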
class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Qwen2ForCausalLM, self).__init__(config)
        config.rope_scaling = None
        self.model = Mineru2QwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.ignore_index = config.ignore_index
        self.image_token_index = config.image_token_index
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def encode_images(self, images: torch.Tensor):
        image_features = self.get_model().vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_model().vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_model().vision_tower.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]

                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                self.get_model().vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError

                        if (
                            "unpad" in mm_patch_merge_type
                            and "anyres_max" in image_aspect_ratio
                            and matched_anyres_max_num_patches
                        ):
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            c, h, w = image_feature.shape
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(
                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
                                )[0]
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, self.ignore_index)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.image_token_index).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = (
                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
                        )
                    )

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
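    # generate() precomputes multimodal `inputs_embeds` from `images` via
    # prepare_inputs_labels_for_multimodal (image placeholder tokens are replaced by
    # projected vision features) and hands only the embeddings to the base Qwen2
    # generation loop; raw `input_ids` are not passed on to super().generate().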
supported") inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal( inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes ) return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): images = kwargs.pop("images", None) image_sizes = kwargs.pop("image_sizes", None) inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs ) if images is not None: inputs["images"] = images if image_sizes is not None: inputs["image_sizes"] = image_sizes return inputs