- # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Iterable, Mapping, Sequence
- from functools import partial
- from typing import List, Optional, Tuple, Union
- import numpy as np
- from .....utils.deps import is_dep_available
- if all(
- map(is_dep_available, ("einops", "torch", "transformers", "vllm", "flash-attn"))
- ):
- import torch
- import torch.nn as nn
- from einops import rearrange, repeat
- from transformers import BatchFeature
- from transformers.activations import GELUActivation
- from transformers.modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- )
- from transformers.utils import torch_int
- from vllm.compilation.decorators import support_torch_compile
- from vllm.config import VllmConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import get_act_fn
- from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
- )
- from vllm.model_executor.layers.logits_processor import LogitsProcessor
- from vllm.model_executor.layers.quantization import QuantizationConfig
- from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
- from vllm.model_executor.model_loader.weight_utils import (
- default_weight_loader,
- maybe_remap_kv_scale_name,
- )
- from vllm.model_executor.models.vision import get_vit_attn_backend
- from vllm.platforms import _Backend, current_platform
- try:
- from vllm.model_executor.models.ernie45 import Ernie4_5_ForCausalLM
- except ImportError:
- from vllm.model_executor.models.ernie45 import (
- Ernie4_5ForCausalLM as Ernie4_5_ForCausalLM,
- )
- from vllm.model_executor.models.interfaces import SupportsMultiModal
- from vllm.model_executor.models.utils import (
- AutoWeightsLoader,
- PPMissingLayer,
- is_pp_missing_parameter,
- merge_multimodal_embeddings,
- )
- from vllm.multimodal import MULTIMODAL_REGISTRY
- from vllm.multimodal.inputs import (
- MultiModalDataDict,
- MultiModalFieldConfig,
- MultiModalKwargs,
- NestedTensors,
- )
- from vllm.multimodal.parse import (
- ImageProcessorItems,
- ImageSize,
- MultiModalDataItems,
- )
- from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
- BaseProcessingInfo,
- PromptReplacement,
- PromptUpdate,
- )
- from vllm.multimodal.profiling import BaseDummyInputsBuilder
- from vllm.sequence import IntermediateTensors
- def smart_resize(
- height: int,
- width: int,
- factor: int = 28,
- min_pixels: int = 28 * 28 * 130,
- max_pixels: int = 28 * 28 * 1280,
- ):
- """Rescales the image so that the following conditions are met:
- 1. Both dimensions (height and width) are divisible by 'factor'.
- 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
- 3. The aspect ratio of the image is maintained as closely as possible.
- """
- # if height < factor or width < factor:
- # raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
- # if int(height < factor//4) + int(width < factor//4):
- # raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor//4}")
- if height < factor:
- print(
- f"smart_resize: height={height} < factor={factor}, reset height=factor"
- )
- width = round((width * factor) / height)
- height = factor
- if width < factor:
- print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
- height = round((height * factor) / width)
- width = factor
- if max(height, width) / min(height, width) > 200:
- raise ValueError(
- f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
- )
- h_bar = round(height / factor) * factor
- w_bar = round(width / factor) * factor
- if h_bar * w_bar > max_pixels:
- beta = math.sqrt((height * width) / max_pixels)
- h_bar = math.floor(height / beta / factor) * factor
- w_bar = math.floor(width / beta / factor) * factor
- elif h_bar * w_bar < min_pixels:
- beta = math.sqrt(min_pixels / (height * width))
- h_bar = math.ceil(height * beta / factor) * factor
- w_bar = math.ceil(width * beta / factor) * factor
- return h_bar, w_bar
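- # Worked example (illustrative input): smart_resize(height=1000, width=750)
- # with the default factor=28 rounds each side to the nearest multiple of 28,
- # giving (1008, 756); 1008 * 756 = 762048 pixels already lies within
- # [min_pixels, max_pixels] = [101920, 1003520], so no further rescaling is
- # applied and (1008, 756) is returned.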
- class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
- def get_hf_config(self):
- return self.ctx.get_hf_config()
- def get_hf_processor(self, **kwargs: object):
- return self.ctx.get_hf_processor(**kwargs)
- def get_image_processor(self, **kwargs: object):
- return self.get_hf_processor(**kwargs).image_processor
- def get_supported_mm_limits(self):
- return {"image": None}
- def get_num_image_tokens(
- self,
- *,
- image_width: int,
- image_height: int,
- image_processor,
- ) -> int:
- if image_processor is None:
- image_processor = self.get_image_processor()
- do_resize = True
- hf_config = self.get_hf_config()
- vision_config = hf_config.vision_config
- patch_size = vision_config.patch_size
- merge_size = vision_config.spatial_merge_size
- if do_resize:
- resized_height, resized_width = smart_resize(
- height=image_height,
- width=image_width,
- factor=patch_size * merge_size,
- min_pixels=image_processor.min_pixels,
- max_pixels=image_processor.max_pixels,
- )
- preprocessed_size = ImageSize(
- width=resized_width, height=resized_height
- )
- else:
- preprocessed_size = ImageSize(width=image_width, height=image_height)
- grid_t = 1
- grid_h = preprocessed_size.height // patch_size
- grid_w = preprocessed_size.width // patch_size
- num_patches = grid_t * grid_h * grid_w
- num_image_tokens = num_patches // (merge_size**2)
- return num_image_tokens
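- # Example (patch_size=14 and spatial_merge_size=2 are hypothetical values
- # used only for illustration): a 1008x756 preprocessed image gives a
- # 72x54 patch grid, i.e. 3888 patches, and 3888 // (2 ** 2) = 972 image
- # tokens.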
- def get_image_size_with_most_features(self) -> ImageSize:
- hf_config = self.get_hf_config()
- image_size = hf_config.vision_config.image_size
- return ImageSize(height=image_size, width=image_size)
- class PaddleOCRVLDummyInputsBuilder(
- BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]
- ):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- num_images = mm_counts.get("image", 0)
- processor = self.info.get_hf_processor()
- image_token = processor.image_token
- return image_token * num_images
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- ) -> MultiModalDataDict:
- num_images = mm_counts.get("image", 0)
- (target_width, target_height) = (
- self.info.get_image_size_with_most_features()
- )
- return {
- "image": self._get_dummy_images(
- width=target_width, height=target_height, num_images=num_images
- )
- }
- class PaddleOCRVLMultiModalProcessor(
- BaseMultiModalProcessor[PaddleOCRVLProcessingInfo]
- ):
- def _call_hf_processor(
- self,
- prompt: str,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> BatchFeature:
- if mm_data:
- processed_outputs = self.info.ctx.call_hf_processor(
- self.info.get_hf_processor(**mm_kwargs),
- dict(text=prompt, **mm_data),
- dict(**mm_kwargs, **tok_kwargs),
- )
- processed_outputs["pixel_values"] = processed_outputs[
- "pixel_values"
- ].unsqueeze(0)
- else:
- tokenizer = self.info.get_tokenizer()
- processed_outputs = tokenizer(
- prompt, add_special_tokens=True, return_tensors="pt"
- )
- return processed_outputs
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- return dict(
- pixel_values=MultiModalFieldConfig.batched("image"),
- image_grid_thw=MultiModalFieldConfig.batched("image"),
- )
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- out_mm_kwargs: MultiModalKwargs,
- ) -> Sequence[PromptUpdate]:
- image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
- hf_config = self.info.get_hf_config()
- image_token_id = hf_config.image_token_id
- def get_replacement(item_idx: int, image_processor):
- images = mm_items.get_items("image", ImageProcessorItems)
- image_size = images.get_image_size(item_idx)
- num_image_tokens = self.info.get_num_image_tokens(
- image_width=image_size.width,
- image_height=image_size.height,
- image_processor=image_processor,
- )
- return [image_token_id] * num_image_tokens
- return [
- PromptReplacement(
- modality="image",
- target=[image_token_id],
- replacement=partial(
- get_replacement, image_processor=image_processor
- ),
- ),
- ]
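- # Each single image placeholder token emitted by the tokenizer is expanded
- # into num_image_tokens copies of image_token_id, so the prompt length seen
- # by the language model matches the number of merged vision embeddings
- # produced for that image.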
- class Projector(nn.Module):
- def __init__(
- self,
- text_config,
- vision_config,
- prefix: str = "",
- ):
- super().__init__()
- self.text_config = text_config
- self.vision_config = vision_config
- self.merge_kernel_size = (2, 2)
- self.hidden_size = (
- self.vision_config.hidden_size
- * self.merge_kernel_size[0]
- * self.merge_kernel_size[1]
- )
- self.pre_norm = torch.nn.LayerNorm(
- self.vision_config.hidden_size, eps=1e-05
- )
- self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
- self.act = GELUActivation()
- self.linear_2 = nn.Linear(
- self.hidden_size, self.text_config.hidden_size, bias=True
- )
- def forward(
- self,
- image_features: torch.Tensor,
- image_grid_thw: List[Tuple[int, int, int]],
- ) -> torch.Tensor:
- m1, m2 = self.merge_kernel_size
- if isinstance(image_features, (list, tuple)):
- processed_features = list()
- for image_feature, image_grid in zip(image_features, image_grid_thw):
- image_feature = self.pre_norm(image_feature)
- t, h, w = image_grid
- image_feature = rearrange(
- image_feature,
- "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
- t=t,
- h=h // m1,
- p1=m1,
- w=w // m2,
- p2=m2,
- )
- hidden_states = self.linear_1(image_feature)
- hidden_states = self.act(hidden_states)
- hidden_states = self.linear_2(hidden_states)
- processed_features.append(hidden_states)
- return processed_features
- dims = image_features.shape[:-1]
- dim = image_features.shape[-1]
- image_features = image_features.view(np.prod(dims), dim)
- hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
- hidden_states = self.linear_1(hidden_states)
- hidden_states = self.act(hidden_states)
- hidden_states = self.linear_2(hidden_states)
- return hidden_states.view(*dims, -1)
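- # The rearrange above folds every 2x2 block of neighbouring vision tokens
- # into the channel dimension before the two-layer MLP, so a 72x54 token grid
- # (illustrative size) goes from (3888, C) to (972, 4 * C) and is then
- # projected to the text model's hidden size.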
- class SiglipVisionEmbeddings(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.image_size = config.image_size
- self.patch_size = config.patch_size
- self.patch_embedding = nn.Conv2d(
- in_channels=config.num_channels,
- out_channels=self.embed_dim,
- kernel_size=self.patch_size,
- stride=self.patch_size,
- padding="valid",
- )
- self.num_patches = (self.image_size // self.patch_size) ** 2
- self.num_positions = self.num_patches
- self.cache_position_embedding = dict()
- self.cache_position_count = dict()
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
- self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
- self.register_buffer(
- "position_ids",
- torch.arange(self.num_positions).expand((1, -1)),
- persistent=False,
- )
- def interpolate_pos_encoding(
- self,
- embeddings: torch.Tensor,
- height: int,
- width: int,
- is_after_patchify: bool = False,
- ) -> torch.Tensor:
- num_positions = self.position_embedding.weight.shape[0]
- patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
- dim = embeddings.shape[-1]
- if is_after_patchify:
- new_height = height
- new_width = width
- else:
- new_height = height // self.patch_size
- new_width = width // self.patch_size
- sqrt_num_positions = torch_int(num_positions**0.5)
- patch_pos_embed = patch_pos_embed.reshape(
- 1, sqrt_num_positions, sqrt_num_positions, dim
- )
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
- patch_pos_embed = nn.functional.interpolate(
- patch_pos_embed,
- size=(new_height, new_width),
- mode="bilinear",
- align_corners=False,
- )
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
- return patch_pos_embed
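- # The pretrained position grid is square (sqrt(num_positions) per side); it
- # is bilinearly resized to the target patch grid so images of arbitrary
- # resolution and aspect ratio can reuse the learned positional embedding.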
- def fetch_position_embedding_lfu_cache(
- self, embeddings, h, w, max_cache: int = 20
- ):
- grid = (h, w)
- if grid in self.cache_position_embedding:
- self.cache_position_count[grid] += 1
- return self.cache_position_embedding[grid]
- if len(self.cache_position_embedding) >= max_cache:
- min_hit_grid = min(
- self.cache_position_count,
- key=self.cache_position_count.get,
- )
- self.cache_position_count.pop(min_hit_grid)
- self.cache_position_embedding.pop(min_hit_grid)
- position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
- self.cache_position_count[grid] = 1
- self.cache_position_embedding[grid] = position_embedding
- return position_embedding
- def forward(
- self,
- pixel_values: torch.FloatTensor,
- position_ids: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[
- List[
- Union[
- Tuple[int, int, int],
- List[Tuple[int, int, int]],
- ]
- ]
- ] = None,
- interpolate_pos_encoding=False,
- ) -> torch.Tensor:
- if pixel_values.dim() == 4:
- pixel_values = pixel_values.unsqueeze(0)
- if pixel_values.dim() == 5:
- if position_ids is None:
- raise ValueError(
- "position_ids cannot be None when pixel_values.dim() is 5."
- )
- (
- batch_size,
- sequence_len,
- channel,
- height,
- width,
- ) = pixel_values.shape
- target_dtype = self.patch_embedding.weight.dtype
- pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
- patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
- embeddings = patch_embeds.flatten(-2).squeeze(-1)
- if interpolate_pos_encoding and image_grid_thw is not None:
- start = 0
- tmp_embeddings = list()
- for image_grid in image_grid_thw:
- t, h, w = image_grid
- end = start + t * h * w
- image_embeddings = embeddings[start:end, :]
- position_embedding = (
- self.interpolate_pos_encoding(image_embeddings, h, w, True)
- .squeeze(0)
- .repeat(t, 1)
- )
- image_embeddings = image_embeddings + position_embedding
- tmp_embeddings.append(image_embeddings)
- start = end
- embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
- else:
- embeddings = embeddings + self.packing_position_embedding(
- position_ids
- )
- return embeddings
- else:
- raise ValueError(
- "Unsupported pixel_values dimension:"
- f" {pixel_values.dim()}. Expected 4 or 5."
- )
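- # Two positional paths are used above: when interpolate_pos_encoding is set
- # and image grids are known, each image gets a bilinearly resized 2D
- # positional embedding; otherwise the learned packing_position_embedding is
- # looked up with the flattened per-image position ids.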
- def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
- if not interleaved:
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
- else:
- x1, x2 = x[..., ::2], x[..., 1::2]
- return rearrange(
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
- )
- def apply_rotary_emb_torch(
- x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
- ) -> torch.Tensor:
- """
- x: (batch_size, seqlen, nheads, headdim)
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
- """
- ro_dim = cos.shape[-1] * 2
- assert ro_dim <= x.shape[-1]
- cos = repeat(
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
- )
- sin = repeat(
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
- )
- return torch.cat(
- [
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
- x[..., ro_dim:],
- ],
- dim=-1,
- )
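- # Only the first ro_dim = 2 * cos.shape[-1] channels of each head are
- # rotated; the remaining channels pass through unchanged. repeat() inserts a
- # head axis so cos/sin broadcast over nheads, and interleaved=False (the
- # default) pairs the two halves of the rotary channels, while
- # interleaved=True pairs adjacent channels.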
- def apply_rotary_pos_emb_flashatt(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- cos = cos.chunk(2, dim=-1)[0].contiguous()
- sin = sin.chunk(2, dim=-1)[0].contiguous()
- apply_rotary_emb = apply_rotary_emb_torch
- if current_platform.is_cuda():
- from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
- q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
- k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
- return q_embed, k_embed
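- # cos/sin arrive duplicated along the last dim (see SiglipEncoder.forward),
- # so only the first half is kept here; the fused rotary kernel from
- # vllm_flash_attn is used on CUDA, with apply_rotary_emb_torch as a
- # pure-PyTorch fallback on other platforms.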
- class SiglipAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You
- Need' paper."""
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- hidden_size = config.hidden_size
- self.hidden_size = config.hidden_size
- tp_size = get_tensor_model_parallel_world_size()
- self.total_num_heads = config.num_attention_heads
- assert self.total_num_heads % tp_size == 0
- self.num_heads = self.total_num_heads // tp_size
- self.total_num_kv_heads = config.num_attention_heads
- if self.total_num_kv_heads >= tp_size:
- assert self.total_num_kv_heads % tp_size == 0
- else:
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
- self.head_dim = config.hidden_size // self.total_num_heads
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
- self.scale = self.head_dim**-0.5
- self.qkv_proj = QKVParallelLinear(
- hidden_size,
- self.head_dim,
- self.total_num_heads,
- self.total_num_kv_heads,
- bias=True,
- quant_config=quant_config,
- prefix=f"{prefix}.qkv_proj",
- )
- self.out_proj = RowParallelLinear(
- input_size=hidden_size,
- output_size=hidden_size,
- quant_config=quant_config,
- prefix=f"{prefix}.out_proj",
- )
- # Detect attention implementation.
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
- if self.attn_backend not in {
- _Backend.FLASH_ATTN,
- _Backend.TORCH_SDPA,
- _Backend.XFORMERS,
- }:
- raise RuntimeError(
- f"PaddleOCR-VL does not support {self.attn_backend} backend now."
- )
- def forward(
- self,
- hidden_states: torch.Tensor,
- cu_seqlens: Optional[List[torch.Tensor]] = None,
- rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> torch.Tensor:
- batch_size, seq_length, embed_dim = hidden_states.shape
- qkv_states, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv_states.chunk(3, dim=-1)
- q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
- k = k.view(batch_size, seq_length, self.num_heads, self.head_dim)
- v = v.view(batch_size, seq_length, self.num_heads, self.head_dim)
- if rope_emb is not None:
- cos, sin = rope_emb
- q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
- if self.attn_backend == _Backend.FLASH_ATTN:
- from flash_attn import flash_attn_varlen_func
- q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- output = flash_attn_varlen_func(
- q,
- k,
- v,
- cu_seqlens_q=cu_seqlens,
- cu_seqlens_k=cu_seqlens,
- max_seqlen_q=max_seqlen,
- max_seqlen_k=max_seqlen,
- )
- context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
- elif self.attn_backend == _Backend.TORCH_SDPA:
- # Execute attention entry by entry for speed & less VRAM.
- import torch.nn.functional as F
- outputs = []
- for i in range(1, len(cu_seqlens)):
- start_idx = cu_seqlens[i - 1]
- end_idx = cu_seqlens[i]
- q_i = q[:, start_idx:end_idx]
- k_i = k[:, start_idx:end_idx]
- v_i = v[:, start_idx:end_idx]
- q_i, k_i, v_i = (
- rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
- )
- output_i = F.scaled_dot_product_attention(
- q_i, k_i, v_i, dropout_p=0.0
- )
- output_i = rearrange(output_i, "b h s d -> b s h d ")
- outputs.append(output_i)
- context_layer = torch.cat(outputs, dim=1)
- elif self.attn_backend == _Backend.XFORMERS:
- from xformers import ops as xops
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- attn_bias = BlockDiagonalMask.from_seqlens(
- q_seqlen=seqlens, kv_seqlen=None, device=q.device
- )
- context_layer = xops.memory_efficient_attention_forward(
- q, k, v, attn_bias=attn_bias, p=0, scale=None
- )
- context_layer = rearrange(
- context_layer, "b s h d -> b s (h d)"
- ).contiguous()
- output, _ = self.out_proj(context_layer)
- return output
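- # All three backends implement the same block-diagonal attention over the
- # packed image sequence: flash-attn uses varlen kernels with cu_seqlens,
- # the SDPA path loops over one image at a time, and xformers builds a
- # BlockDiagonalMask, so tokens never attend across image boundaries.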
- class SigLIPRotaryEmbedding(nn.Module):
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- self.dim = dim
- self.theta = theta
- self.rope_init()
- def rope_init(self):
- inv_freq = 1.0 / (
- self.theta
- ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- def forward(self, seqlen: int) -> torch.Tensor:
- seq = torch.arange(
- seqlen,
- device=self.inv_freq.device,
- dtype=self.inv_freq.dtype,
- )
- freqs = torch.outer(seq, self.inv_freq)
- return freqs
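- # Returns a (seqlen, dim // 2) table of angles, angle[p, i] =
- # p / theta ** (2 * i / dim); SiglipEncoder indexes this table with
- # per-token row and column ids to build the 2D rotary embedding.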
- class SiglipMLP(nn.Module):
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.config = config
- self.activation_fn = get_act_fn(config.hidden_act)
- # Special handling for BNB and torchao quantization
- if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
- quantizable = True
- else:
- # For other quantization, we require the hidden size to be a
- # multiple of 64
- quantizable = (
- config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
- )
- self.fc1 = ColumnParallelLinear(
- config.hidden_size,
- config.intermediate_size,
- quant_config=quant_config if quantizable else None,
- prefix=f"{prefix}.fc1",
- )
- self.fc2 = RowParallelLinear(
- config.intermediate_size,
- config.hidden_size,
- quant_config=quant_config if quantizable else None,
- prefix=f"{prefix}.fc2",
- )
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states, _ = self.fc1(hidden_states)
- hidden_states = self.activation_fn(hidden_states)
- hidden_states, _ = self.fc2(hidden_states)
- return hidden_states
- class SiglipEncoderLayer(nn.Module):
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.embed_dim = config.hidden_size
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
- self.self_attn = SiglipAttention(
- config,
- quant_config=quant_config,
- prefix=f"{prefix}.self_attn",
- )
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
- self.mlp = SiglipMLP(
- config,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp",
- )
- def forward(
- self,
- hidden_states: torch.Tensor,
- cu_seqlens: Optional[List[torch.Tensor]] = None,
- rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> Tuple[torch.FloatTensor]:
- residual = hidden_states
- hidden_states = self.layer_norm1(hidden_states)
- hidden_states = self.self_attn(
- hidden_states=hidden_states,
- cu_seqlens=cu_seqlens,
- rope_emb=rope_emb,
- )
- hidden_states = residual + hidden_states
- residual = hidden_states
- hidden_states = self.layer_norm2(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
- class SiglipEncoder(nn.Module):
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- embed_dim = config.hidden_size
- num_heads = config.num_attention_heads
- head_dim = embed_dim // num_heads
- self.layers = nn.ModuleList(
- [
- SiglipEncoderLayer(
- config,
- quant_config=quant_config,
- prefix=f"{prefix}.layers.{layer_idx}",
- )
- for layer_idx in range(config.num_hidden_layers)
- ]
- )
- self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
- @staticmethod
- def flatten_list(image_grid_thw):
- tmp_image_grid_thw = list()
- for image_grid in image_grid_thw:
- if isinstance(image_grid, list):
- tmp_image_grid_thw.extend(image_grid)
- else:
- tmp_image_grid_thw.append(image_grid)
- return tmp_image_grid_thw
- def forward(
- self,
- inputs_embeds,
- cu_seqlens: Optional[List[torch.Tensor]] = None,
- image_grid_thw: Optional[
- List[
- Union[
- Tuple[int, int, int],
- List[Tuple[int, int, int]],
- ]
- ]
- ] = None,
- height_position_ids: Optional[torch.Tensor] = None,
- width_position_ids: Optional[torch.Tensor] = None,
- ) -> BaseModelOutput:
- device = inputs_embeds.device
- hidden_states = inputs_embeds
- flatten_image_grid_thw = self.flatten_list(image_grid_thw)
- if width_position_ids is None or height_position_ids is None:
- split_hids = list()
- split_wids = list()
- for t, h, w in flatten_image_grid_thw:
- image_pids = torch.arange(t * h * w, device=device) % (h * w)
- sample_hids = image_pids // w
- sample_wids = image_pids % w
- split_hids.append(sample_hids)
- split_wids.append(sample_wids)
- width_position_ids = torch.concat(split_wids, dim=0)
- height_position_ids = torch.concat(split_hids, dim=0)
- pids = torch.stack(
- [height_position_ids, width_position_ids],
- dim=-1,
- )
- max_grid_size = pids.max() + 1
- rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
- rope_emb = rope_emb_max_grid[pids].flatten(1)
- rope_emb = rope_emb.repeat(1, 2)
- rope_emb = (rope_emb.cos(), rope_emb.sin())
- attn_cu_seqlens = cu_seqlens
- hidden_states = inputs_embeds
- for encoder_layer in self.layers:
- hidden_states = encoder_layer(
- hidden_states,
- cu_seqlens=attn_cu_seqlens,
- rope_emb=rope_emb,
- )
- return hidden_states
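- # Example of the 2D position ids built above: for a single (t=1, h=2, w=3)
- # grid the height ids are [0, 0, 0, 1, 1, 1] and the width ids are
- # [0, 1, 2, 0, 1, 2]; both index the same angle table, so half of each
- # head's rotary channels encode the token's row and the other half its
- # column.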
- class SiglipVisionTransformer(nn.Module):
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- embed_dim = config.hidden_size
- self.embeddings = SiglipVisionEmbeddings(config)
- self.encoder = SiglipEncoder(
- config,
- quant_config=quant_config,
- prefix=f"{prefix}.encoder",
- )
- self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
- def forward(
- self,
- pixel_values,
- interpolate_pos_encoding: Optional[bool] = False,
- position_ids: Optional[torch.Tensor] = None,
- height_position_ids: Optional[torch.Tensor] = None,
- width_position_ids: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[List[torch.Tensor]] = None,
- image_grid_thw: Optional[
- List[
- Union[
- Tuple[int, int, int],
- List[Tuple[int, int, int]],
- ]
- ]
- ] = None,
- ) -> BaseModelOutputWithPooling:
- hidden_states = self.embeddings(
- pixel_values,
- interpolate_pos_encoding=interpolate_pos_encoding,
- position_ids=position_ids,
- image_grid_thw=image_grid_thw,
- )
- last_hidden_state = self.encoder(
- inputs_embeds=hidden_states,
- cu_seqlens=cu_seqlens,
- image_grid_thw=image_grid_thw,
- height_position_ids=height_position_ids,
- width_position_ids=width_position_ids,
- )
- last_hidden_state = self.post_layernorm(last_hidden_state)
- sample_hidden_state = list()
- if cu_seqlens is None:
- raise ValueError(
- "cu_seqlens cannot be None for "
- "SiglipVisionTransformer output processing."
- )
- for i in range(cu_seqlens.shape[0] - 1):
- start = cu_seqlens[i]
- end = cu_seqlens[i + 1]
- tensor = last_hidden_state[:, start:end, :].squeeze(0)
- sample_hidden_state.append(tensor)
- return sample_hidden_state
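- # The packed (1, total_patches, C) sequence is split back into per-image
- # tensors via cu_seqlens, e.g. cu_seqlens = [0, 1008, 1260] (illustrative,
- # from grids (1, 36, 28) and (1, 18, 14)) yields tensors of shape (1008, C)
- # and (252, C).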
- class SiglipVisionModel(nn.Module):
- config_class = "PaddleOCRVisionConfig"
- main_input_name = "pixel_values"
- def __init__(
- self,
- config,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.vision_model = SiglipVisionTransformer(
- config,
- quant_config=quant_config,
- prefix=f"{prefix}.vision_model",
- )
- self.quant_config = quant_config
- @property
- def dtype(self) -> torch.dtype:
- return self.vision_model.embeddings.patch_embedding.weight.dtype
- @property
- def device(self) -> torch.device:
- return self.vision_model.embeddings.patch_embedding.weight.device
- def get_input_embeddings(self) -> nn.Module:
- return self.vision_model.embeddings.patch_embedding
- def forward(
- self,
- pixel_values,
- interpolate_pos_encoding: bool = False,
- position_ids: Optional[torch.Tensor] = None,
- image_grid_thw: Optional[
- List[
- Union[
- Tuple[int, int, int],
- List[Tuple[int, int, int]],
- ]
- ]
- ] = None,
- cu_seqlens: Optional[List[torch.Tensor]] = None,
- ) -> BaseModelOutputWithPooling:
- return self.vision_model(
- pixel_values=pixel_values,
- interpolate_pos_encoding=interpolate_pos_encoding,
- position_ids=position_ids,
- image_grid_thw=image_grid_thw,
- cu_seqlens=cu_seqlens,
- )
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
- stacked_params_mapping = [
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ]
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: set[str] = set()
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- if "head.attention" in name or "head.layernorm" in name:
- continue
- if "head.mlp" in name or "head.probe" in name:
- continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(
- param,
- "weight_loader",
- default_weight_loader,
- )
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
- for (
- param_name,
- weight_name,
- shard_id,
- ) in stacked_params_mapping:
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- if name.endswith(".bias") and name not in params_dict:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- if name.endswith(".bias") and name not in params_dict:
- continue
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param,
- "weight_loader",
- default_weight_loader,
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
- @MULTIMODAL_REGISTRY.register_processor(
- PaddleOCRVLMultiModalProcessor,
- info=PaddleOCRVLProcessingInfo,
- dummy_inputs=PaddleOCRVLDummyInputsBuilder,
- )
- @support_torch_compile(
- # set dynamic_arg_dims to support mrope
- dynamic_arg_dims={
- "input_ids": 0,
- "positions": -1,
- "intermediate_tensors": 0,
- "inputs_embeds": 0,
- }
- )
- class PaddleOCRVLForConditionalGeneration(Ernie4_5_ForCausalLM, SupportsMultiModal):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__(vllm_config=vllm_config, prefix=prefix)
- config = self.config
- self.mlp_AR = Projector(config, config.vision_config)
- self.visual = SiglipVisionModel(config=config.vision_config)
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
- self.logits_processor = LogitsProcessor(config.vocab_size)
- for layer in self.model.layers:
- if not isinstance(layer, PPMissingLayer):
- layer.self_attn.rotary_emb.is_neox_style = True
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- sampling_metadata,
- ) -> Optional[torch.Tensor]:
- logits = self.logits_processor(
- self.lm_head, hidden_states, sampling_metadata
- )
- return logits
- @property
- def language_model(self):
- return self.model
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- **kwargs,
- ):
- if intermediate_tensors is not None:
- inputs_embeds = None
- elif inputs_embeds is None:
- vision_embeddings = self.get_multimodal_embeddings(**kwargs)
- inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)
- input_ids = None
- return self.language_model(
- input_ids, positions, intermediate_tensors, inputs_embeds
- )
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
- if modality.startswith("image"):
- return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
- raise ValueError("Only image modality is supported")
- def encode_image(self, pixel_values, image_grid_thw):
- pixel_values = pixel_values.type(self.visual.dtype)
- siglip_position_ids = list()
- image_grid_hws = list()
- cu_seqlens = [0]
- for idx, thw in enumerate(image_grid_thw):
- thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
- numel = np.prod(thw_tuple)
- image_grid_hws.append(thw_tuple)
- image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
- siglip_position_ids.append(image_position_ids)
- cu_seqlens.append(cu_seqlens[-1] + numel)
- siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
- pixel_values.device
- )
- cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
- pixel_values.device
- )
- vision_outputs = self.visual(
- pixel_values=pixel_values,
- image_grid_thw=image_grid_hws,
- position_ids=siglip_position_ids,
- interpolate_pos_encoding=True,
- cu_seqlens=cu_seqlens,
- )
- image_embeds = self.mlp_AR(vision_outputs, image_grid_thw)
- return image_embeds
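- # Example (illustrative grids): two images with grid_thw (1, 72, 54) and
- # (1, 36, 28) give cu_seqlens = [0, 3888, 4896], and each image's
- # siglip_position_ids run from 0 to h * w - 1, restarting at 0 for every
- # image in the batch.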
- def get_multimodal_embeddings(self, **kwargs):
- pixel_values = kwargs["pixel_values"]
- image_grid_thw = kwargs["image_grid_thw"]
- multimodal_embeddings = []
- for pv, ig in zip(pixel_values, image_grid_thw):
- if pv is not None:
- image_embeds = self.encode_image(pv, ig)
- multimodal_embeddings += image_embeds
- return multimodal_embeddings
- def get_input_embeddings(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings: Optional[NestedTensors] = None,
- ) -> torch.Tensor:
- inputs_embeds = self.language_model.get_input_embeddings(input_ids)
- if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
- inputs_embeds = merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- multimodal_embeddings,
- self.config.image_token_id,
- )
- return inputs_embeds
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- autoloaded_weights = loader.load_weights(weights)
- return autoloaded_weights