
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np

from .....utils.deps import is_dep_available

if all(
    map(is_dep_available, ("einops", "torch", "transformers", "vllm", "flash-attn"))
):
    import torch
    import torch.nn as nn
    from einops import rearrange, repeat
    from transformers import BatchFeature
    from transformers.activations import GELUActivation
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPooling,
    )
    from transformers.utils import torch_int
    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import VllmConfig
    from vllm.distributed import get_tensor_model_parallel_world_size
    from vllm.model_executor.layers.activation import get_act_fn
    from vllm.model_executor.layers.linear import (
        ColumnParallelLinear,
        QKVParallelLinear,
        RowParallelLinear,
    )
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    from vllm.model_executor.layers.quantization import QuantizationConfig
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader,
        maybe_remap_kv_scale_name,
    )
    from vllm.model_executor.models.vision import get_vit_attn_backend
    from vllm.platforms import _Backend, current_platform

    try:
        from vllm.model_executor.models.ernie45 import Ernie4_5_ForCausalLM
    except ImportError:
        from vllm.model_executor.models.ernie45 import (
            Ernie4_5ForCausalLM as Ernie4_5_ForCausalLM,
        )

    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.model_executor.models.utils import (
        AutoWeightsLoader,
        PPMissingLayer,
        is_pp_missing_parameter,
        merge_multimodal_embeddings,
    )
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalFieldConfig,
        MultiModalKwargs,
        NestedTensors,
    )
    from vllm.multimodal.parse import (
        ImageProcessorItems,
        ImageSize,
        MultiModalDataItems,
    )
    from vllm.multimodal.processing import (
        BaseMultiModalProcessor,
        BaseProcessingInfo,
        PromptReplacement,
        PromptUpdate,
    )
    from vllm.multimodal.profiling import BaseDummyInputsBuilder
    from vllm.sequence import IntermediateTensors

    def smart_resize(
        height: int,
        width: int,
        factor: int = 28,
        min_pixels: int = 28 * 28 * 130,
        max_pixels: int = 28 * 28 * 1280,
    ):
        """Rescales the image so that the following conditions are met:
        1. Both dimensions (height and width) are divisible by 'factor'.
        2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
        3. The aspect ratio of the image is maintained as closely as possible.
        """
        # if height < factor or width < factor:
        #     raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
        # if int(height < factor//4) + int(width < factor//4):
        #     raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor//4}")
        if height < factor:
            print(
                f"smart_resize: height={height} < factor={factor}, reset height=factor"
            )
            width = round((width * factor) / height)
            height = factor
        if width < factor:
            print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
            height = round((height * factor) / width)
            width = factor
        if max(height, width) / min(height, width) > 200:
            raise ValueError(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar
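
    # Illustrative usage sketch (comment only, not part of the original module):
    # with the default factor of 28 and max_pixels = 28 * 28 * 1280 = 1003520,
    # a 1080x1920 image first rounds to 1092x1932, overflows the pixel budget,
    # and is then scaled down while both sides stay multiples of 28:
    #
    #     h_bar, w_bar = smart_resize(1080, 1920)  # -> (728, 1316)
    #     assert h_bar % 28 == 0 and w_bar % 28 == 0
    #     assert h_bar * w_bar <= 28 * 28 * 1280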

    class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
        def get_hf_config(self):
            return self.ctx.get_hf_config()

        def get_hf_processor(self, **kwargs: object):
            return self.ctx.get_hf_processor(**kwargs)

        def get_image_processor(self, **kwargs: object):
            return self.get_hf_processor(**kwargs).image_processor

        def get_supported_mm_limits(self):
            return {"image": None}

        def get_num_image_tokens(
            self,
            *,
            image_width: int,
            image_height: int,
            image_processor,
        ) -> int:
            if image_processor is None:
                image_processor = self.get_image_processor()
            do_resize = True
            hf_config = self.get_hf_config()
            vision_config = hf_config.vision_config
            patch_size = vision_config.patch_size
            merge_size = vision_config.spatial_merge_size
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height=image_height,
                    width=image_width,
                    factor=patch_size * merge_size,
                    min_pixels=image_processor.min_pixels,
                    max_pixels=image_processor.max_pixels,
                )
                preprocessed_size = ImageSize(
                    width=resized_width, height=resized_height
                )
            else:
                preprocessed_size = ImageSize(width=image_width, height=image_height)
            grid_t = 1
            grid_h = preprocessed_size.height // patch_size
            grid_w = preprocessed_size.width // patch_size
            num_patches = grid_t * grid_h * grid_w
            num_image_tokens = num_patches // (merge_size**2)
            return num_image_tokens

        def get_image_size_with_most_features(self) -> ImageSize:
            hf_config = self.get_hf_config()
            image_size = hf_config.vision_config.image_size
            return ImageSize(height=image_size, width=image_size)
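
    # Illustrative sketch (patch_size and spatial_merge_size are assumed values,
    # not read from a real checkpoint): with patch_size=14 and
    # spatial_merge_size=2, an image resized to 560x560 gives a 40x40 patch
    # grid, and the 2x2 spatial merge collapses it to 400 image tokens:
    #
    #     grid_h = grid_w = 560 // 14           # 40
    #     num_patches = 1 * grid_h * grid_w     # 1600
    #     num_image_tokens = num_patches // 4   # 400, i.e. // (merge_size ** 2)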

    class PaddleOCRVLDummyInputsBuilder(
        BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]
    ):
        def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
            num_images = mm_counts.get("image", 0)
            processor = self.info.get_hf_processor()
            image_token = processor.image_token
            return image_token * num_images

        def get_dummy_mm_data(
            self,
            seq_len: int,
            mm_counts: Mapping[str, int],
        ) -> MultiModalDataDict:
            num_images = mm_counts.get("image", 0)
            (target_width, target_height) = (
                self.info.get_image_size_with_most_features()
            )
            return {
                "image": self._get_dummy_images(
                    width=target_width, height=target_height, num_images=num_images
                )
            }

    class PaddleOCRVLMultiModalProcessor(
        BaseMultiModalProcessor[PaddleOCRVLProcessingInfo]
    ):
        def _call_hf_processor(
            self,
            prompt: str,
            mm_data: Mapping[str, object],
            mm_kwargs: Mapping[str, object],
            tok_kwargs: Mapping[str, object],
        ) -> BatchFeature:
            if mm_data:
                processed_outputs = self.info.ctx.call_hf_processor(
                    self.info.get_hf_processor(**mm_kwargs),
                    dict(text=prompt, **mm_data),
                    dict(**mm_kwargs, **tok_kwargs),
                )
                processed_outputs["pixel_values"] = processed_outputs[
                    "pixel_values"
                ].unsqueeze(0)
            else:
                tokenizer = self.info.get_tokenizer()
                processed_outputs = tokenizer(
                    prompt, add_special_tokens=True, return_tensors="pt"
                )
            return processed_outputs

        def _get_mm_fields_config(
            self,
            hf_inputs: BatchFeature,
            hf_processor_mm_kwargs: Mapping[str, object],
        ) -> Mapping[str, MultiModalFieldConfig]:
            return dict(
                pixel_values=MultiModalFieldConfig.batched("image"),
                image_grid_thw=MultiModalFieldConfig.batched("image"),
            )

        def _get_prompt_updates(
            self,
            mm_items: MultiModalDataItems,
            hf_processor_mm_kwargs: Mapping[str, object],
            out_mm_kwargs: MultiModalKwargs,
        ) -> Sequence[PromptUpdate]:
            image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
            hf_config = self.info.get_hf_config()
            image_token_id = hf_config.image_token_id

            def get_replacement(item_idx: int, image_processor):
                images = mm_items.get_items("image", ImageProcessorItems)
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    image_processor=image_processor,
                )
                return [image_token_id] * num_image_tokens

            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=partial(
                        get_replacement, image_processor=image_processor
                    ),
                ),
            ]

    class Projector(nn.Module):
        def __init__(
            self,
            text_config,
            vision_config,
            prefix: str = "",
        ):
            super().__init__()
            self.text_config = text_config
            self.vision_config = vision_config
            self.merge_kernel_size = (2, 2)
            self.hidden_size = (
                self.vision_config.hidden_size
                * self.merge_kernel_size[0]
                * self.merge_kernel_size[1]
            )
            self.pre_norm = torch.nn.LayerNorm(
                self.vision_config.hidden_size, eps=1e-05
            )
            self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
            self.act = GELUActivation()
            self.linear_2 = nn.Linear(
                self.hidden_size, self.text_config.hidden_size, bias=True
            )

        def forward(
            self,
            image_features: torch.Tensor,
            image_grid_thw: List[Tuple[int, int, int]],
        ) -> torch.Tensor:
            m1, m2 = self.merge_kernel_size
            if isinstance(image_features, (list, tuple)):
                processed_features = list()
                for image_feature, image_grid in zip(image_features, image_grid_thw):
                    image_feature = self.pre_norm(image_feature)
                    t, h, w = image_grid
                    image_feature = rearrange(
                        image_feature,
                        "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                        t=t,
                        h=h // m1,
                        p1=m1,
                        w=w // m2,
                        p2=m2,
                    )
                    hidden_states = self.linear_1(image_feature)
                    hidden_states = self.act(hidden_states)
                    hidden_states = self.linear_2(hidden_states)
                    processed_features.append(hidden_states)
                return processed_features

            dims = image_features.shape[:-1]
            dim = image_features.shape[-1]
            image_features = image_features.view(np.prod(dims), dim)
            hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
            hidden_states = self.linear_1(hidden_states)
            hidden_states = self.act(hidden_states)
            hidden_states = self.linear_2(hidden_states)
            return hidden_states.view(*dims, -1)
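
    # Illustrative shape sketch (hidden sizes assumed): for one image with grid
    # (t, h, w) = (1, 40, 40) and a vision hidden size of 1152, the rearrange in
    # Projector.forward groups each 2x2 neighbourhood of patches, so
    # (1600, 1152) features become (400, 4608) before linear_1/linear_2 project
    # them to the text hidden size:
    #
    #     image_feature: (1 * 40 * 40, 1152) -> (1 * 20 * 20, 2 * 2 * 1152)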

    class SiglipVisionEmbeddings(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.embed_dim = config.hidden_size
            self.image_size = config.image_size
            self.patch_size = config.patch_size
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            self.num_patches = (self.image_size // self.patch_size) ** 2
            self.num_positions = self.num_patches
            self.cache_position_embedding = dict()
            self.cache_position_count = dict()
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
            self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
            self.register_buffer(
                "position_ids",
                torch.arange(self.num_positions).expand((1, -1)),
                persistent=False,
            )

        def interpolate_pos_encoding(
            self,
            embeddings: torch.Tensor,
            height: int,
            width: int,
            is_after_patchify: bool = False,
        ) -> torch.Tensor:
            num_positions = self.position_embedding.weight.shape[0]
            patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
            dim = embeddings.shape[-1]
            if is_after_patchify:
                new_height = height
                new_width = width
            else:
                new_height = height // self.patch_size
                new_width = width // self.patch_size
            sqrt_num_positions = torch_int(num_positions**0.5)
            patch_pos_embed = patch_pos_embed.reshape(
                1, sqrt_num_positions, sqrt_num_positions, dim
            )
            patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed,
                size=(new_height, new_width),
                mode="bilinear",
                align_corners=False,
            )
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
            return patch_pos_embed

        def fetch_position_embedding_lfu_cache(
            self, embeddings, h, w, max_cache: int = 20
        ):
            grid = (h, w)
            if grid in self.cache_position_embedding:
                self.cache_position_count[grid] += 1
                return self.cache_position_embedding[grid]
            if len(self.cache_position_embedding) >= max_cache:
                min_hit_grid = min(
                    self.cache_position_count,
                    key=self.cache_position_count.get,
                )
                self.cache_position_count.pop(min_hit_grid)
                self.cache_position_embedding.pop(min_hit_grid)
            position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
            self.cache_position_count[grid] = 1
            self.cache_position_embedding[grid] = position_embedding
            return position_embedding

        def forward(
            self,
            pixel_values: torch.FloatTensor,
            position_ids: Optional[torch.Tensor] = None,
            image_grid_thw: Optional[
                List[
                    Union[
                        Tuple[int, int, int],
                        List[Tuple[int, int, int]],
                    ]
                ]
            ] = None,
            interpolate_pos_encoding=False,
        ) -> torch.Tensor:
            if pixel_values.dim() == 4:
                pixel_values = pixel_values.unsqueeze(0)
            if pixel_values.dim() == 5:
                if position_ids is None:
                    raise ValueError(
                        "position_ids cannot be None when pixel_values.dim() is 5."
                    )
                (
                    batch_size,
                    sequence_len,
                    channel,
                    height,
                    width,
                ) = pixel_values.shape
                target_dtype = self.patch_embedding.weight.dtype
                pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
                patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
                embeddings = patch_embeds.flatten(-2).squeeze(-1)
                if interpolate_pos_encoding and image_grid_thw is not None:
                    start = 0
                    tmp_embeddings = list()
                    for image_grid in image_grid_thw:
                        t, h, w = image_grid
                        end = start + t * h * w
                        image_embeddings = embeddings[start:end, :]
                        position_embedding = (
                            self.interpolate_pos_encoding(image_embeddings, h, w, True)
                            .squeeze(0)
                            .repeat(t, 1)
                        )
                        image_embeddings = image_embeddings + position_embedding
                        tmp_embeddings.append(image_embeddings)
                        start = end
                    embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
                else:
                    embeddings = embeddings + self.packing_position_embedding(
                        position_ids
                    )
                return embeddings
            else:
                raise ValueError(
                    "Unsupported pixel_values dimension:"
                    f" {pixel_values.dim()}. Expected 4 or 5."
                )

    def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
        if not interleaved:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        else:
            x1, x2 = x[..., ::2], x[..., 1::2]
            return rearrange(
                torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
            )

    def apply_rotary_emb_torch(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
    ) -> torch.Tensor:
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
        """
        ro_dim = cos.shape[-1] * 2
        assert ro_dim <= x.shape[-1]
        cos = repeat(
            cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
        )
        sin = repeat(
            sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
        )
        return torch.cat(
            [
                x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
                x[..., ro_dim:],
            ],
            dim=-1,
        )

    def apply_rotary_pos_emb_flashatt(
        q: torch.Tensor,
        k: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = cos.chunk(2, dim=-1)[0].contiguous()
        sin = sin.chunk(2, dim=-1)[0].contiguous()
        apply_rotary_emb = apply_rotary_emb_torch
        if current_platform.is_cuda():
            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
        q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
        k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
        return q_embed, k_embed
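
    # Illustrative sketch (shapes assumed): apply_rotary_emb_torch rotates only
    # the first ro_dim = 2 * cos.shape[-1] channels of each head and passes the
    # remaining channels through unchanged, e.g.:
    #
    #     x = torch.randn(1, 16, 8, 64)  # (batch, seqlen, nheads, headdim)
    #     inv = 1.0 / 10000 ** (torch.arange(0, 32, 2).float() / 32)
    #     freqs = torch.outer(torch.arange(16, dtype=torch.float), inv)  # (16, 16)
    #     out = apply_rotary_emb_torch(x, freqs.cos(), freqs.sin())
    #     assert out.shape == x.shape  # first 32 channels rotated, last 32 untouched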

    class SiglipAttention(nn.Module):
        """Multi-headed attention from 'Attention Is All You Need' paper."""

        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ):
            super().__init__()
            self.config = config
            hidden_size = config.hidden_size
            self.hidden_size = config.hidden_size
            tp_size = get_tensor_model_parallel_world_size()
            self.total_num_heads = config.num_attention_heads
            assert self.total_num_heads % tp_size == 0
            self.num_heads = self.total_num_heads // tp_size
            self.total_num_kv_heads = config.num_attention_heads
            if self.total_num_kv_heads >= tp_size:
                assert self.total_num_kv_heads % tp_size == 0
            else:
                assert tp_size % self.total_num_kv_heads == 0
            self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
            self.head_dim = config.hidden_size // self.total_num_heads
            self.q_size = self.num_heads * self.head_dim
            self.kv_size = self.num_kv_heads * self.head_dim
            self.scale = self.head_dim**-0.5
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
            self.out_proj = RowParallelLinear(
                input_size=hidden_size,
                output_size=hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.out_proj",
            )
            # Detect attention implementation.
            self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
            if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
            }:
                raise RuntimeError(
                    f"PaddleOCR-VL does not support {self.attn_backend} backend now."
                )

        def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: Optional[List[torch.Tensor]] = None,
            rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        ) -> torch.Tensor:
            batch_size, seq_length, embed_dim = hidden_states.shape
            qkv_states, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv_states.chunk(3, dim=-1)
            q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
            k = k.view(batch_size, seq_length, self.num_heads, self.head_dim)
            v = v.view(batch_size, seq_length, self.num_heads, self.head_dim)
            if rope_emb is not None:
                cos, sin = rope_emb
                q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
            if self.attn_backend == _Backend.FLASH_ATTN:
                from flash_attn import flash_attn_varlen_func

                q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
                output = flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=max_seqlen,
                    max_seqlen_k=max_seqlen,
                )
                context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            elif self.attn_backend == _Backend.TORCH_SDPA:
                # Execute attention entry by entry for speed & less VRAM.
                import torch.nn.functional as F

                outputs = []
                for i in range(1, len(cu_seqlens)):
                    start_idx = cu_seqlens[i - 1]
                    end_idx = cu_seqlens[i]
                    q_i = q[:, start_idx:end_idx]
                    k_i = k[:, start_idx:end_idx]
                    v_i = v[:, start_idx:end_idx]
                    q_i, k_i, v_i = (
                        rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                    )
                    output_i = F.scaled_dot_product_attention(
                        q_i, k_i, v_i, dropout_p=0.0
                    )
                    output_i = rearrange(output_i, "b h s d -> b s h d")
                    outputs.append(output_i)
                context_layer = torch.cat(outputs, dim=1)
            elif self.attn_backend == _Backend.XFORMERS:
                from xformers import ops as xops
                from xformers.ops.fmha.attn_bias import BlockDiagonalMask

                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
                attn_bias = BlockDiagonalMask.from_seqlens(
                    q_seqlen=seqlens, kv_seqlen=None, device=q.device
                )
                context_layer = xops.memory_efficient_attention_forward(
                    q, k, v, attn_bias=attn_bias, p=0, scale=None
                )
            context_layer = rearrange(
                context_layer, "b s h d -> b s (h d)"
            ).contiguous()
            output, _ = self.out_proj(context_layer)
            return output

    class SigLIPRotaryEmbedding(nn.Module):
        def __init__(self, dim: int, theta: float = 10000.0) -> None:
            super().__init__()
            self.dim = dim
            self.theta = theta
            self.rope_init()

        def rope_init(self):
            inv_freq = 1.0 / (
                self.theta
                ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        def forward(self, seqlen: int) -> torch.Tensor:
            seq = torch.arange(
                seqlen,
                device=self.inv_freq.device,
                dtype=self.inv_freq.dtype,
            )
            freqs = torch.outer(seq, self.inv_freq)
            return freqs
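
    # Illustrative note (grid values assumed): SiglipEncoder.forward below looks
    # up one row of this frequency table per height index and one per width
    # index of every patch, so each patch gets a 2D rotary phase:
    #
    #     rope = SigLIPRotaryEmbedding(dim=32)  # head_dim // 2 for head_dim=64
    #     freqs = rope(40)                      # (40, 16), one row per grid coordinate
    #     # pids[i] = (h_i, w_i); freqs[pids].flatten(1) -> (num_patches, 32)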

    class SiglipMLP(nn.Module):
        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ) -> None:
            super().__init__()
            self.config = config
            self.activation_fn = get_act_fn(config.hidden_act)
            # Special handling for BNB and torchao quantization
            if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
                quantizable = True
            else:
                # For other quantization, we require the hidden size to be a
                # multiple of 64
                quantizable = (
                    config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
                )
            self.fc1 = ColumnParallelLinear(
                config.hidden_size,
                config.intermediate_size,
                quant_config=quant_config if quantizable else None,
                prefix=f"{prefix}.fc1",
            )
            self.fc2 = RowParallelLinear(
                config.intermediate_size,
                config.hidden_size,
                quant_config=quant_config if quantizable else None,
                prefix=f"{prefix}.fc2",
            )

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states, _ = self.fc1(hidden_states)
            hidden_states = self.activation_fn(hidden_states)
            hidden_states, _ = self.fc2(hidden_states)
            return hidden_states

    class SiglipEncoderLayer(nn.Module):
        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ):
            super().__init__()
            self.embed_dim = config.hidden_size
            self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.self_attn = SiglipAttention(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
            self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.mlp = SiglipMLP(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: Optional[List[torch.Tensor]] = None,
            rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        ) -> Tuple[torch.FloatTensor]:
            residual = hidden_states
            hidden_states = self.layer_norm1(hidden_states)
            hidden_states = self.self_attn(
                hidden_states=hidden_states,
                cu_seqlens=cu_seqlens,
                rope_emb=rope_emb,
            )
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.layer_norm2(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states
            return hidden_states

    class SiglipEncoder(nn.Module):
        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ):
            super().__init__()
            self.config = config
            embed_dim = config.hidden_size
            num_heads = config.num_attention_heads
            head_dim = embed_dim // num_heads
            self.layers = nn.ModuleList(
                [
                    SiglipEncoderLayer(
                        config,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}",
                    )
                    for layer_idx in range(config.num_hidden_layers)
                ]
            )
            self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

        @staticmethod
        def flatten_list(image_grid_thw):
            tmp_image_grid_thw = list()
            for image_grid in image_grid_thw:
                if isinstance(image_grid, list):
                    tmp_image_grid_thw.extend(image_grid)
                else:
                    tmp_image_grid_thw.append(image_grid)
            return tmp_image_grid_thw

        def forward(
            self,
            inputs_embeds,
            cu_seqlens: Optional[List[torch.Tensor]] = None,
            image_grid_thw: Optional[
                List[
                    Union[
                        Tuple[int, int, int],
                        List[Tuple[int, int, int]],
                    ]
                ]
            ] = None,
            height_position_ids: Optional[torch.Tensor] = None,
            width_position_ids: Optional[torch.Tensor] = None,
        ) -> BaseModelOutput:
            device = inputs_embeds.device
            hidden_states = inputs_embeds
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)
            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    image_pids = torch.arange(t * h * w, device=device) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)
            pids = torch.stack(
                [height_position_ids, width_position_ids],
                dim=-1,
            )
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
            attn_cu_seqlens = cu_seqlens
            hidden_states = inputs_embeds
            for encoder_layer in self.layers:
                hidden_states = encoder_layer(
                    hidden_states,
                    cu_seqlens=attn_cu_seqlens,
                    rope_emb=rope_emb,
                )
            return hidden_states

    class SiglipVisionTransformer(nn.Module):
        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ):
            super().__init__()
            self.config = config
            embed_dim = config.hidden_size
            self.embeddings = SiglipVisionEmbeddings(config)
            self.encoder = SiglipEncoder(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.encoder",
            )
            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        def forward(
            self,
            pixel_values,
            interpolate_pos_encoding: Optional[bool] = False,
            position_ids: Optional[torch.Tensor] = None,
            height_position_ids: Optional[torch.Tensor] = None,
            width_position_ids: Optional[torch.Tensor] = None,
            cu_seqlens: Optional[List[torch.Tensor]] = None,
            image_grid_thw: Optional[
                List[
                    Union[
                        Tuple[int, int, int],
                        List[Tuple[int, int, int]],
                    ]
                ]
            ] = None,
        ) -> BaseModelOutputWithPooling:
            hidden_states = self.embeddings(
                pixel_values,
                interpolate_pos_encoding=interpolate_pos_encoding,
                position_ids=position_ids,
                image_grid_thw=image_grid_thw,
            )
            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                cu_seqlens=cu_seqlens,
                image_grid_thw=image_grid_thw,
                height_position_ids=height_position_ids,
                width_position_ids=width_position_ids,
            )
            last_hidden_state = self.post_layernorm(last_hidden_state)
            sample_hidden_state = list()
            if cu_seqlens is None:
                raise ValueError(
                    "cu_seqlens cannot be None for "
                    "SiglipVisionTransformer output processing."
                )
            for i in range(cu_seqlens.shape[0] - 1):
                start = cu_seqlens[i]
                end = cu_seqlens[i + 1]
                tensor = last_hidden_state[:, start:end, :].squeeze(0)
                sample_hidden_state.append(tensor)
            return sample_hidden_state

    class SiglipVisionModel(nn.Module):
        config_class = "PaddleOCRVisionConfig"
        main_input_name = "pixel_values"

        def __init__(
            self,
            config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ):
            super().__init__()
            self.vision_model = SiglipVisionTransformer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.vision_model",
            )
            self.quant_config = quant_config

        @property
        def dtype(self) -> torch.dtype:
            return self.vision_model.embeddings.patch_embedding.weight.dtype

        @property
        def device(self) -> torch.device:
            return self.vision_model.embeddings.patch_embedding.weight.device

        def get_input_embeddings(self) -> nn.Module:
            return self.vision_model.embeddings.patch_embedding

        def forward(
            self,
            pixel_values,
            interpolate_pos_encoding: bool = False,
            position_ids: Optional[torch.Tensor] = None,
            image_grid_thw: Optional[
                List[
                    Union[
                        Tuple[int, int, int],
                        List[Tuple[int, int, int]],
                    ]
                ]
            ] = None,
            cu_seqlens: Optional[List[torch.Tensor]] = None,
        ) -> BaseModelOutputWithPooling:
            return self.vision_model(
                pixel_values=pixel_values,
                interpolate_pos_encoding=interpolate_pos_encoding,
                position_ids=position_ids,
                image_grid_thw=image_grid_thw,
                cu_seqlens=cu_seqlens,
            )

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
            stacked_params_mapping = [
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
            ]
            params_dict = dict(self.named_parameters(remove_duplicate=False))
            loaded_params: set[str] = set()
            for name, loaded_weight in weights:
                if "rotary_emb.inv_freq" in name:
                    continue
                if "head.attention" in name or "head.layernorm" in name:
                    continue
                if "head.mlp" in name or "head.probe" in name:
                    continue
                if self.quant_config is not None and (
                    scale_name := self.quant_config.get_cache_scale(name)
                ):
                    param = params_dict[scale_name]
                    weight_loader = getattr(
                        param,
                        "weight_loader",
                        default_weight_loader,
                    )
                    loaded_weight = (
                        loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(scale_name)
                    continue
                for (
                    param_name,
                    weight_name,
                    shard_id,
                ) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param,
                        "weight_loader",
                        default_weight_loader,
                    )
                    weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return loaded_params
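
    # Illustrative sketch (the checkpoint name is assumed): the stacked mapping
    # above folds separate q/k/v checkpoint tensors into the fused
    # QKVParallelLinear, e.g. a tensor named
    #     "vision_model.encoder.layers.0.self_attn.q_proj.weight"
    # is renamed to
    #     "vision_model.encoder.layers.0.self_attn.qkv_proj.weight"
    # and loaded via weight_loader(param, loaded_weight, "q").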

    @MULTIMODAL_REGISTRY.register_processor(
        PaddleOCRVLMultiModalProcessor,
        info=PaddleOCRVLProcessingInfo,
        dummy_inputs=PaddleOCRVLDummyInputsBuilder,
    )
    @support_torch_compile(
        # set dynamic_arg_dims to support mrope
        dynamic_arg_dims={
            "input_ids": 0,
            "positions": -1,
            "intermediate_tensors": 0,
            "inputs_embeds": 0,
        }
    )
    class PaddleOCRVLForConditionalGeneration(Ernie4_5_ForCausalLM, SupportsMultiModal):
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__(vllm_config=vllm_config, prefix=prefix)
            config = self.config
            self.mlp_AR = Projector(config, config.vision_config)
            self.visual = SiglipVisionModel(config=config.vision_config)
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
            self.logits_processor = LogitsProcessor(config.vocab_size)
            for layer in self.model.layers:
                if not isinstance(layer, PPMissingLayer):
                    layer.self_attn.rotary_emb.is_neox_style = True

        def compute_logits(
            self,
            hidden_states: torch.Tensor,
            sampling_metadata,
        ) -> Optional[torch.Tensor]:
            logits = self.logits_processor(
                self.lm_head, hidden_states, sampling_metadata
            )
            return logits

        @property
        def language_model(self):
            return self.model

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs,
        ):
            if intermediate_tensors is not None:
                inputs_embeds = None
            elif inputs_embeds is None:
                vision_embeddings = self.get_multimodal_embeddings(**kwargs)
                inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)
                input_ids = None
            return self.language_model(
                input_ids, positions, intermediate_tensors, inputs_embeds
            )

        @classmethod
        def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
            if modality.startswith("image"):
                return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
            raise ValueError("Only image modality is supported")

        def encode_image(self, pixel_values, image_grid_thw):
            pixel_values = pixel_values.type(self.visual.dtype)
            siglip_position_ids = list()
            image_grid_hws = list()
            cu_seqlens = [0]
            for idx, thw in enumerate(image_grid_thw):
                thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
                numel = np.prod(thw_tuple)
                image_grid_hws.append(thw_tuple)
                image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
                siglip_position_ids.append(image_position_ids)
                cu_seqlens.append(cu_seqlens[-1] + numel)
            siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
                pixel_values.device
            )
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                pixel_values.device
            )
            vision_outputs = self.visual(
                pixel_values=pixel_values,
                image_grid_thw=image_grid_hws,
                position_ids=siglip_position_ids,
                interpolate_pos_encoding=True,
                cu_seqlens=cu_seqlens,
            )
            image_embeds = self.mlp_AR(vision_outputs, image_grid_thw)
            return image_embeds
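
        # Illustrative sketch (grids assumed): for two images with
        # image_grid_thw = [(1, 20, 30), (1, 40, 40)], the loop above builds
        #
        #     image_grid_hws  == [(1, 20, 30), (1, 40, 40)]
        #     cu_seqlens      == [0, 600, 2200]
        #     position ids    == [0 .. 599, 0 .. 1599]
        #
        # so the vision tower runs both images as one packed sequence and splits
        # the hidden states back per image afterwards.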

        def get_multimodal_embeddings(self, **kwargs):
            pixel_values = kwargs["pixel_values"]
            image_grid_thw = kwargs["image_grid_thw"]
            multimodal_embeddings = []
            for pv, ig in zip(pixel_values, image_grid_thw):
                if pv is not None:
                    image_embeds = self.encode_image(pv, ig)
                    multimodal_embeddings += image_embeds
            return multimodal_embeddings

        def get_input_embeddings(
            self,
            input_ids: torch.Tensor,
            multimodal_embeddings: Optional[NestedTensors] = None,
        ) -> torch.Tensor:
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    multimodal_embeddings,
                    self.config.image_token_id,
                )
            return inputs_embeds

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
            loader = AutoWeightsLoader(self)
            autoloaded_weights = loader.load_weights(weights)
            return autoloaded_weights