# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is based on https://github.com/Kwai-Keye/Keye/blob/main/keye-vl-8b-preview/modeling_keye.py
# Original header:
# Copyright 2025 The Keye Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: Weight initialization
from typing import List, Optional, Tuple, Union

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from einops import rearrange

from ....common.vlm.activations import ACT2FN
from ....common.vlm.transformers import PretrainedModel
from ....common.vlm.transformers.model_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
)
from ._config import PaddleOCRVLConfig, PPOCRVisionConfig


def rotate_half(x):
    """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
    Dh = x.shape[-1]
    x1 = x[..., : Dh // 2]
    x2 = x[..., Dh // 2 :]
    return paddle.concat([-x2, x1], axis=-1)


def _ensure_cos_sin_dim(cos, sin, dim_needed):
    """Duplicate half-dim cos/sin tables so their last dim matches the head dim."""
    last = cos.shape[-1]
    if last == dim_needed:
        return cos, sin
    elif last * 2 == dim_needed:
        cos = paddle.concat([cos, cos], axis=-1)
        sin = paddle.concat([sin, sin], axis=-1)
        return cos, sin
    else:
        raise ValueError(
            f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed // 2}"
        )


def apply_rotary_pos_emb_vision(q, k, cos, sin):
    """Apply rotary position embeddings to q/k (computed in float32 for stability)."""
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q = q.astype("float32")
    k = k.astype("float32")
    Dh = q.shape[-1]
    cos = cos.astype("float32")
    sin = sin.astype("float32")
    cos, sin = _ensure_cos_sin_dim(cos, sin, Dh)
    cos = cos.unsqueeze(-2)  # broadcast over the head axis
    sin = sin.unsqueeze(-2)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.astype(orig_q_dtype), k_embed.astype(orig_k_dtype)
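
# Shape sketch (illustrative, not part of the original module): with q and k of
# shape [B, L, H, Dh] and rotary tables of shape [L, Dh // 2], the helper above
# widens cos/sin to Dh and broadcasts them across the head axis. All tensor
# values below are made up for the example.
#
#   q = paddle.randn([1, 16, 4, 32])  # [B, L, H, Dh]
#   k = paddle.randn([1, 16, 4, 32])
#   inv_freq = 1.0 / 10000.0 ** (paddle.arange(0, 16, dtype="float32") / 16)
#   freqs = paddle.outer(paddle.arange(16, dtype="float32"), inv_freq)  # [L, Dh // 2]
#   q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, freqs.cos(), freqs.sin())
#   assert q_rot.shape == q.shape and k_rot.shape == k.shape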


def eager_attention_forward(
    module,
    query,
    key,
    value,
    attention_mask,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,  # absorbs extra flags such as is_causal
):
    # query/key/value: [B, H, L, Dh]
    attn_weights = paddle.matmul(query, key.transpose((0, 1, 3, 2))) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = paddle.matmul(attn_weights, value)
    attn_output = attn_output.transpose((0, 2, 1, 3)).contiguous()  # -> [B, L, H, Dh]
    return attn_output, attn_weights


class SiglipAttention(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim
        self.scale = self.head_dim**-0.5
        self.dropout = getattr(config, "attention_dropout", 0.0)
        self.is_causal = False
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: paddle.Tensor,  # [B, L, D]
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[List[paddle.Tensor]] = None,
        rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,  # (cos, sin)
    ):
        B, L, D = hidden_states.shape
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # [B, L, H, Dh]
        q = q.reshape([B, L, self.num_heads, self.head_dim])
        k = k.reshape([B, L, self.num_heads, self.head_dim])
        v = v.reshape([B, L, self.num_heads, self.head_dim])
        if rope_emb is not None:
            cos, sin = rope_emb
            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
        # → [B, H, L, Dh]
        q = q.transpose([0, 2, 1, 3])
        k = k.transpose([0, 2, 1, 3])
        v = v.transpose([0, 2, 1, 3])
        attn_output, attn_weights = eager_attention_forward(
            self,
            q,
            k,
            v,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )
        attn_output = attn_output.reshape([B, L, D]).contiguous()
        attn_output = self.out_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class SiglipVisionEmbeddings(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # e.g. 1152
        self.image_size = config.image_size  # e.g. 384
        self.patch_size = config.patch_size  # e.g. 14
        # Note: Paddle expects "VALID" or 0 here for no padding.
        self.patch_embedding = nn.Conv2D(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="VALID",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2  # e.g. 729
        self.num_positions = self.num_patches
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
        self.register_buffer(
            "position_ids",
            paddle.arange(self.num_positions).unsqueeze(0),
            persistable=False,
        )

    def interpolate_pos_encoding(
        self, embeddings, height: int, width: int, is_after_patchify: bool = False
    ):
        num_positions = self.position_embedding.weight.shape[0]
        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
        dim = embeddings.shape[-1]
        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size
        sqrt_num_positions = int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            (1, sqrt_num_positions, sqrt_num_positions, dim)
        )
        patch_pos_embed = patch_pos_embed.transpose((0, 3, 1, 2))
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.transpose((0, 2, 3, 1)).reshape((1, -1, dim))
        return patch_pos_embed

    @staticmethod
    def flatten_list(image_grid_thw):
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache=20):
        # Least-frequently-used cache of interpolated position embeddings per (h, w) grid.
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]
        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count, key=self.cache_position_count.get
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)
        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: paddle.Tensor,  # [B, L, C, H, W]
        position_ids: Optional[paddle.Tensor] = None,  # [B or 1, S]
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        interpolate_pos_encoding: bool = False,
    ) -> paddle.Tensor:
        if pixel_values.dim() == 5:
            assert position_ids is not None
            batch_size, sequence_len, channel, height, width = pixel_values.shape
            target_dtype = self.patch_embedding.weight.dtype
            pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype)
            )  # shape = [*, width, grid, grid]
            embeddings = patch_embeds.flatten(-2).squeeze(-1)
            embeddings = rearrange(
                embeddings, "(b l) d -> b l d", b=batch_size, l=sequence_len
            )
            # TODO: this branch has not been debugged yet
            if interpolate_pos_encoding and image_grid_thw is not None:
                flatten_image_grid_thw = self.flatten_list(image_grid_thw)
                assert batch_size == 1
                start = 0
                assert (
                    sum([np.prod(x) for x in flatten_image_grid_thw])
                    == embeddings.shape[1]
                ), (flatten_image_grid_thw, embeddings.shape)
                embeddings = embeddings.squeeze(0)
                tmp_embeddings = list()
                for image_grid in image_grid_thw:
                    t, h, w = image_grid
                    end = start + t * h * w
                    image_embeddings = embeddings[int(start) : int(end), :]
                    position_embedding = (
                        self.interpolate_pos_encoding(image_embeddings, h, w, True)
                        .squeeze(0)
                        .tile((t, 1))
                    )
                    image_embeddings = image_embeddings + position_embedding
                    tmp_embeddings.append(image_embeddings)
                    start = end
                embeddings = paddle.concat(tmp_embeddings, axis=0).unsqueeze(0)
            else:
                embeddings = embeddings + self.packing_position_embedding(position_ids)
            return embeddings
        else:
            raise NotImplementedError(str(pixel_values.shape))


class SiglipMLP(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
        self.self_attn = SiglipAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions=False,
        cu_seqlens=None,
        rope_emb=None,
    ):
        # Pre-LN attention block
        residual = hidden_states
        ln1_out = self.layer_norm1(hidden_states)
        x, attn_w = self.self_attn(
            hidden_states=ln1_out,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )
        hidden_states = residual + x
        # Pre-LN MLP block
        residual = hidden_states
        ln2_out = self.layer_norm2(residual)
        mlp_out = self.mlp(ln2_out)
        hidden_states_out = residual + mlp_out
        outputs = (hidden_states_out,)
        if output_attentions:
            outputs += (attn_w,)
        return outputs


class SigLIPRotaryEmbedding(nn.Layer):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        arange = paddle.arange(0, self.dim, 2, dtype="float32")
        inv_freq = 1.0 / (self.theta ** (arange / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistable=False)

    def forward(self, seqlen: int) -> paddle.Tensor:
        seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype)
        freqs = paddle.outer(seq, self.inv_freq)
        return freqs
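
# How the encoder consumes these frequencies (a sketch of the wiring used in
# SiglipEncoder.forward below; the concrete numbers are assumptions): the
# half-dim table is gathered once per (h, w) coordinate, flattened, and tiled
# to the full head dim before taking cos/sin.
#
#   rope = SigLIPRotaryEmbedding(dim=16)         # head_dim // 2 for a 32-dim head
#   freqs = rope(24)                             # [24, 8], one row per coordinate
#   pids = paddle.randint(0, 24, shape=[10, 2])  # per-patch (h, w) ids
#   emb = freqs[pids].flatten(1).tile((1, 2))    # [10, 32] == head_dim
#   cos, sin = emb.cos(), emb.sin()              # ready for apply_rotary_pos_emb_vision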


class SiglipEncoder(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.LayerList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
        self.gradient_checkpointing = False

    @staticmethod
    def flatten_list(image_grid_thw):
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def build_window_index(self, image_grid, window_size):
        """
        Returns:
            window_indices: int64 [sum(t * h * w_valid)]
            cu_seqlens_within_windows: int32 [num_windows_total * t],
                prefix sums with a leading 0.
        """
        window_indices = list()
        pad_values = -100
        start_window_index = 0
        cu_seqlens_within_windows = list()
        for grid in image_grid:
            t, h, w = map(int, grid)
            window_index = paddle.arange(t * h * w).reshape((t, h, w))
            pad_h = (-h) % window_size
            pad_w = (-w) % window_size
            assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w)
            # Pad the (h, w) plane so both sides divide evenly by window_size.
            # Paddle's F.pad pads the last two axes of a 4-D NCHW tensor, so
            # temporarily add a leading axis.
            window_index = F.pad(
                window_index.unsqueeze(0), (0, pad_w, 0, pad_h), value=pad_values
            ).squeeze(0)
            window_index = rearrange(
                window_index,
                "t (h p1) (w p2) -> t (h w) (p1 p2)",
                p1=window_size,
                p2=window_size,
            )
            window_seqlens = (
                (window_index != pad_values).astype("int64").sum(-1).reshape([-1])
            )
            window_index = window_index.reshape([-1])
            window_index = window_index[window_index != pad_values]
            window_indices.append(window_index + start_window_index)
            cu_seqlens_within_windows.append(
                window_seqlens.cumsum(0) + start_window_index
            )
            start_window_index += t * h * w
        window_indices = paddle.concat(window_indices, axis=0)
        cu_seqlens_within_windows = paddle.concat(cu_seqlens_within_windows, axis=0)
        cu_seqlens_within_windows = F.pad(
            cu_seqlens_within_windows, (1, 0), value=0
        ).astype("int32")
        return window_indices, cu_seqlens_within_windows
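
    # Worked example (sketch): one 1x4x4 grid with window_size=2 regroups the
    # 16 patch indices into four 2x2 windows; the expected outputs below were
    # traced by hand from the logic above.
    #
    #   enc = SiglipEncoder(config)  # config: any PPOCRVisionConfig
    #   idx, cu = enc.build_window_index([(1, 4, 4)], window_size=2)
    #   # idx -> [0, 1, 4, 5,  2, 3, 6, 7,  8, 9, 12, 13,  10, 11, 14, 15]
    #   # cu  -> [0, 4, 8, 12, 16]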

    def forward(
        self,
        inputs_embeds: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[paddle.Tensor] = None,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        height_position_ids: Optional[paddle.Tensor] = None,
        width_position_ids: Optional[paddle.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
        vision_or_text: str = "vision",
    ):
        # This encoder is only used as the vision tower, so the mode is forced.
        vision_or_text = "vision"
        assert vision_or_text in ["vision", "text"]
        use_window_attn = window_size > 0 and vision_or_text == "vision"
        use_rope = (use_rope is True) and (vision_or_text == "vision")
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        hidden_states = inputs_embeds
        attention_mask = (
            attention_mask.to(inputs_embeds.dtype)
            if attention_mask is not None
            else None
        )
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)
            assert (
                sum([np.prod(x) for x in flatten_image_grid_thw])
                == hidden_states.shape[1]
            ), (flatten_image_grid_thw, hidden_states.shape)
            # Derive per-patch (h, w) coordinates if they were not given.
            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    t, h, w = map(int, (t, h, w))
                    image_pids = paddle.arange(t * h * w) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = paddle.concat(split_wids, axis=0)
                height_position_ids = paddle.concat(split_hids, axis=0)
            window_indices, cu_seqlens_within_windows = None, None
            if use_window_attn:
                window_indices, cu_seqlens_within_windows = self.build_window_index(
                    flatten_image_grid_thw, window_size
                )
                reversed_window_indices = window_indices.argsort()
                height_position_ids = height_position_ids[window_indices]
                width_position_ids = width_position_ids[window_indices]
            pids = paddle.stack(
                [height_position_ids, width_position_ids], axis=-1
            ).astype(paddle.int64)
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.tile((1, 2))
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None
            window_indices, cu_seqlens_within_windows = None, None
            if use_window_attn:
                flatten_image_grid_thw = self.flatten_list(image_grid_thw)
                assert (
                    sum(
                        [
                            np.prod(x.astype("float32").cpu().numpy())
                            for x in flatten_image_grid_thw
                        ]
                    )
                    == hidden_states.shape[1]
                ), (flatten_image_grid_thw, hidden_states.shape)
                window_indices, cu_seqlens_within_windows = self.build_window_index(
                    flatten_image_grid_thw, window_size
                )
                reversed_window_indices = window_indices.argsort()
        if use_window_attn:
            assert cu_seqlens_within_windows is not None
            attn_cu_seqlens = cu_seqlens_within_windows
            # Regroup the sequence so that each window is contiguous.
            hidden_states = hidden_states[:, window_indices, :]
        else:
            attn_cu_seqlens = cu_seqlens
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (
                    (hidden_states[:, reversed_window_indices, :],)
                    if use_window_attn
                    else (hidden_states,)
                )
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=attn_cu_seqlens,
                rope_emb=rope_emb,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        if use_window_attn:
            # Undo the window regrouping to restore the original patch order.
            hidden_states = hidden_states[:, reversed_window_indices, :]
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class SiglipMultiheadAttentionPoolingHead(nn.Layer):
    """Multihead Attention Pooling."""

    def __init__(self, config: PPOCRVisionConfig):
        super().__init__()
        self.probe = self.create_parameter(
            shape=(1, 1, config.hidden_size),
            default_initializer=paddle.nn.initializer.Normal(),
        )
        self.attention = nn.MultiHeadAttention(
            config.hidden_size, config.num_attention_heads
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state, key_padding_mask=None):
        batch_size = hidden_state.shape[0]
        probe = self.probe.tile((batch_size, 1, 1))
        # paddle.nn.MultiHeadAttention has no torch-style key_padding_mask and
        # returns the output tensor directly; translate the padding mask
        # (1 = padded) into a boolean attn_mask (True = may attend).
        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = (key_padding_mask == 0).reshape([batch_size, 1, 1, -1])
        hidden_state = self.attention(
            probe, hidden_state, hidden_state, attn_mask=attn_mask
        )
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        return hidden_state[:, 0]
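
# Minimal pooling sketch (shapes are assumptions, not from the original file):
# the learned probe attends over each sequence and one pooled vector per sample
# comes back.
#
#   head = SiglipMultiheadAttentionPoolingHead(config)
#   feats = paddle.randn([2, 49, config.hidden_size])
#   pooled = head(feats)  # [2, hidden_size]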


class SiglipVisionTransformer(nn.Layer):
    def __init__(self, config: PPOCRVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, epsilon=config.layer_norm_eps)
        self.use_head = (
            True if not hasattr(config, "vision_use_head") else config.vision_use_head
        )
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask=None,
        sample_indices=None,
        image_indices=None,
        position_ids=None,
        height_position_ids=None,
        width_position_ids=None,
        cu_seqlens=None,
        padding_mask=None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
    ) -> BaseModelOutputWithPooling:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )
        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)
        if return_pooler_output is True:
            if sample_indices is not None:
                assert self.use_head is True
                dim = last_hidden_state.shape[-1]
                sample_hidden_state_list = list()
                hidden_state = last_hidden_state.squeeze(0)
                sample_index = sample_indices
                # paddle.unique already returns sorted values; unbind them into
                # a list of per-sample scalar tensors.
                unique_sample_index = list(
                    paddle.unbind(paddle.unique(sample_index), axis=0)
                )
                if len(unique_sample_index) > 0 and unique_sample_index[0] == -1:
                    unique_sample_index = unique_sample_index[1:]
                for sample_idx in unique_sample_index:
                    token_indices = (sample_index == sample_idx).nonzero().flatten()
                    sample_hidden_state = hidden_state[token_indices]
                    sample_hidden_state_list.append(sample_hidden_state)
                if not vision_return_embed_list:
                    # Right-pad every sample to the same length and pool them in
                    # one batch; the mask marks padded positions with 1.
                    max_length = max(
                        [_state.shape[0] for _state in sample_hidden_state_list]
                    )
                    tmp_sample_hidden_state_list = list()
                    padding_mask = list()
                    for idx, _state in enumerate(sample_hidden_state_list):
                        padding_length = max_length - _state.shape[0]
                        mask = paddle.zeros((max_length,), dtype=paddle.int64)
                        if padding_length > 0:
                            mask[-padding_length:] = 1
                        padding_mask.append(mask)
                        padding = paddle.zeros(
                            (padding_length, dim), dtype=_state.dtype
                        )
                        new_state = paddle.concat([_state, padding], axis=0)
                        tmp_sample_hidden_state_list.append(new_state)
                    sample_hidden_state = paddle.stack(
                        tmp_sample_hidden_state_list, axis=0
                    )
                    padding_mask = paddle.stack(padding_mask, axis=0).astype(
                        last_hidden_state.dtype
                    )
                    pooler_output = self.head(
                        sample_hidden_state, key_padding_mask=padding_mask
                    )
                else:
                    pooler_output = list()
                    for state in sample_hidden_state_list:
                        sample_pooler_output = self.head(state.unsqueeze(0))
                        pooler_output.append(sample_pooler_output)
                    pooler_output = paddle.concat(pooler_output, axis=0)
                    sample_hidden_state = sample_hidden_state_list
                return BaseModelOutputWithPooling(
                    last_hidden_state=sample_hidden_state,
                    pooler_output=pooler_output,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                )
            else:
                pooler_output = self.head(last_hidden_state) if self.use_head else None
                return BaseModelOutputWithPooling(
                    last_hidden_state=last_hidden_state,
                    pooler_output=pooler_output,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                )
        # return_pooler_output is False: split the packed sequence back into
        # per-sample embeddings along the cu_seqlens boundaries.
        sample_hidden_state = list()
        assert cu_seqlens is not None
        for i in range(cu_seqlens.shape[0] - 1):
            start = int(cu_seqlens[i])
            end = int(cu_seqlens[i + 1])
            tensor = last_hidden_state[:, start:end, :].squeeze(0)
            sample_hidden_state.append(tensor)
        return BaseModelOutputWithPooling(
            last_hidden_state=sample_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class SiglipPreTrainedModel(PretrainedModel):
    config_class = PaddleOCRVLConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True


class SiglipVisionModel(SiglipPreTrainedModel):
    config_class = PPOCRVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: PPOCRVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)

    def get_input_embeddings(self) -> nn.Layer:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids=None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        cu_seqlens=None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
    ) -> BaseModelOutputWithPooling:
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )
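
# End-to-end usage sketch (illustrative; the real preprocessing lives in the
# PaddleOCR-VL pipeline, so the packed shapes and flags below are assumptions):
#
#   config = PPOCRVisionConfig()
#   model = SiglipVisionModel(config)
#   t, h, w = 1, 24, 24                      # one image of 24x24 patches
#   pixel_values = paddle.randn(
#       [1, t * h * w, 3, config.patch_size, config.patch_size]
#   )                                        # [B, L, C, patch, patch]
#   position_ids = paddle.arange(t * h * w).unsqueeze(0)
#   out = model(
#       pixel_values,
#       position_ids=position_ids,
#       image_grid_thw=[(t, h, w)],
#       cu_seqlens=paddle.to_tensor([0, t * h * w], dtype="int32"),
#       use_rope=True,
#       return_pooler_output=False,          # return per-sample embeddings
#   )
#   per_sample = out.last_hidden_state       # list with one [t*h*w, hidden] tensor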