# GOT_ocr_2_0.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import List, Optional, Tuple, Type

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ...common.vlm.transformers.model_outputs import CausalLMOutputWithPast
from .qwen2 import Qwen2Config, Qwen2ForCausalLM, Qwen2Model


class MLPBlock(paddle.nn.Layer):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[paddle.nn.Layer] = paddle.nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class LayerNorm2d(paddle.nn.Layer):
    def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None:
        super().__init__()
        self.weight = paddle.base.framework.EagerParamBase.from_tensor(
            tensor=paddle.ones(shape=num_channels)
        )
        self.bias = paddle.base.framework.EagerParamBase.from_tensor(
            tensor=paddle.zeros(shape=num_channels)
        )
        self.epsilon = epsilon

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        u = x.mean(axis=1, keepdim=True)
        s = (x - u).pow(y=2).mean(axis=1, keepdim=True)
        x = (x - u) / paddle.sqrt(x=s + self.epsilon)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
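
# Note (added sketch, not part of the upstream file): unlike nn.LayerNorm, this
# layer normalizes an NCHW feature map over its channel axis (axis=1) only; it is
# used by the convolutional neck of ImageEncoderViT below. A minimal usage sketch:
#
#     ln = LayerNorm2d(num_channels=256)
#     y = ln(paddle.randn([2, 256, 64, 64]))  # output keeps the [2, 256, 64, 64] shape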


class ImageEncoderViT(paddle.nn.Layer):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Layer] = nn.LayerNorm,
        act_layer: Type[nn.Layer] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Layer): Normalization layer.
            act_layer (nn.Layer): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.pos_embed: Optional[paddle.base.framework.EagerParamBase.from_tensor] = (
            None
        )
        if use_abs_pos:
            self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(
                    shape=[1, img_size // patch_size, img_size // patch_size, embed_dim]
                )
            )
        self.blocks = paddle.nn.LayerList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)
        self.neck = nn.Sequential(
            nn.Conv2D(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias_attr=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2D(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias_attr=False,
            ),
            LayerNorm2d(out_chans),
        )
        self.net_2 = nn.Conv2D(
            256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False
        )
        self.net_3 = nn.Conv2D(
            512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.neck(x.transpose([0, 3, 1, 2]))
        x = self.net_2(x)
        x = self.net_3(x)
        return x
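
# Shape walk-through (added sketch) for the configuration instantiated in
# GOTQwenModel below (img_size=1024, patch_size=16, out_chans=256); verify the
# numbers against your own inputs:
#     image        [B, 3, 1024, 1024]
#     patch_embed  [B, 64, 64, 768]    (1024 // 16 = 64 tokens per side)
#     neck         [B, 256, 64, 64]
#     net_2        [B, 512, 32, 32]    (stride-2 conv)
#     net_3        [B, 1024, 16, 16]   (stride-2 conv)
# GOTQwenModel.forward then flattens this to [B, 256, 1024], i.e. 256 visual
# tokens of width 1024, before mm_projector_vary.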


class Block(paddle.nn.Layer):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Layer] = nn.LayerNorm,
        act_layer: Type[nn.Layer] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Layer): Normalization layer.
            act_layer (nn.Layer): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(
            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
        )
        self.window_size = window_size

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x


class Attention(paddle.nn.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim])
            )
            self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor(
                tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim])
            )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        B, H, W, _ = tuple(x.shape)
        qkv = (
            self.qkv(x)
            .reshape([B, H * W, 3, self.num_heads, -1])
            .transpose([2, 0, 3, 1, 4])
        )
        q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0)
        attn = (q * self.scale) @ k.transpose([0, 2, 1])
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
        attn = F.softmax(attn, axis=-1)
        x = (
            (attn @ v)
            .reshape([B, self.num_heads, H, W, -1])
            .transpose([0, 2, 3, 1, 4])
            .reshape([B, H, W, -1])
        )
        x = self.proj(x)
        return x


def window_partition(
    x: paddle.Tensor, window_size: int
) -> Tuple[paddle.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = tuple(x.shape)
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, pad=(0, pad_w, 0, pad_h), data_format="NHWC")
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(
        [B, Hp // window_size, window_size, Wp // window_size, window_size, C]
    )
    windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
    return windows, (Hp, Wp)


def window_unpartition(
    windows: paddle.Tensor,
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> paddle.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size)
    x = windows.reshape(
        [B, Hp // window_size, Wp // window_size, window_size, window_size, -1]
    )
    x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1])
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :]
    return x
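
# Round-trip sketch (added example, assuming an illustrative [1, 32, 32, 8] input
# and window_size=14): window_partition pads H and W up to 42 and yields 9 windows,
# and window_unpartition crops the padding back off.
#
#     x = paddle.randn([1, 32, 32, 8])
#     windows, pad_hw = window_partition(x, 14)               # [9, 14, 14, 8], (42, 42)
#     y = window_unpartition(windows, 14, pad_hw, (32, 32))   # [1, 32, 32, 8]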


def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if tuple(rel_pos.shape)[0] != max_rel_dist:
        # Interpolate the stored table to the required number of relative offsets.
        rel_pos_resized = paddle.nn.functional.interpolate(
            rel_pos.reshape([1, tuple(rel_pos.shape)[0], -1]).transpose([0, 2, 1]),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0])
    else:
        rel_pos_resized = rel_pos
    # Scale the coordinates with the shorter length if q and k sizes differ.
    q_coords = paddle.arange(end=q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = paddle.arange(end=k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = q_coords - k_coords + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos_resized[relative_coords.astype(dtype="int64")]


def add_decomposed_rel_pos(
    attn: paddle.Tensor,
    q: paddle.Tensor,
    rel_pos_h: paddle.Tensor,
    rel_pos_w: paddle.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> paddle.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
    B, _, dim = tuple(q.shape)
    r_q = q.reshape([B, q_h, q_w, dim])
    rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw)
    attn = (
        attn.reshape([B, q_h, q_w, k_h, k_w])
        + rel_h[:, :, :, :, None]
        + rel_w[:, :, :, None, :]
    ).reshape([B, q_h * q_w, k_h * k_w])
    return attn
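
# Index-form summary (added note) of the decomposition implemented above:
#     attn[b, (h, w), (kh, kw)] += <q[b, h, w, :], Rh[h, kh, :]>
#                                + <q[b, h, w, :], Rw[w, kw, :]>
# i.e. the height and width relative-position terms are computed separately and
# broadcast over the other axis, instead of materializing a full 2-D bias table.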


class PatchEmbed(paddle.nn.Layer):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2D(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.transpose([0, 2, 3, 1])
        return x


DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
DEFAULT_IM_START_TOKEN = "<img>"
DEFAULT_IM_END_TOKEN = "</img>"
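
# Added note: these placeholder strings mark where image features are spliced into
# the prompt. GOTQwenModel.forward below hard-codes the matching token ids
# (im_start_token=151857, im_end_token=151858, im_patch_token=151859); the exact
# string-to-id mapping comes from the GOT tokenizer and should be checked against
# the tokenizer shipped with the checkpoint.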


class Qwen2LMHead(nn.Layer):
    def __init__(
        self,
        config,
        embedding_weights=None,
        transpose_y=False,
        tensor_parallel_output=1,
    ):
        super(Qwen2LMHead, self).__init__()
        self.config = config
        vocab_size = config.vocab_size
        self.transpose_y = transpose_y
        if transpose_y:
            # only for weight from embedding_weights
            if embedding_weights is not None:
                self.weight = embedding_weights
            else:
                self.weight = self.create_parameter(
                    shape=[vocab_size, config.hidden_size],
                    dtype=paddle.get_default_dtype(),
                )
        else:
            # for weight from model init
            self.weight = self.create_parameter(
                shape=[config.hidden_size, vocab_size],
                dtype=paddle.get_default_dtype(),
            )

    def forward(self, hidden_states, tensor_parallel_output=1):
        logits = paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y)
        return logits
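
# Added sketch of the two head layouts: with transpose_y=True the head reuses a
# [vocab_size, hidden_size] embedding matrix and computes hidden @ W.T (tied
# weights); otherwise it owns a fresh [hidden_size, vocab_size] parameter.
#
#     head = Qwen2LMHead(config, embedding_weights=embed_tokens.weight, transpose_y=True)
#     logits = head(hidden_states)  # [batch, seq_len, vocab_size]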


class GOTConfig(Qwen2Config):
    model_type = "GOT"


class GOTQwenModel(Qwen2Model):
    config_class = GOTConfig

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)
        self.vision_tower_high = ImageEncoderViT(
            depth=12,
            embed_dim=768,
            img_size=1024,
            mlp_ratio=4,
            norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6),
            num_heads=12,
            patch_size=16,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=[2, 5, 8, 11],
            window_size=14,
            out_chans=256,
        )
        self.mm_projector_vary = nn.Linear(1024, 1024)

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[paddle.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, "orig_embeds_params", None)
        if orig_embeds_params is not None:
            with paddle.no_grad():
                self.get_input_embeddings().weight[: -self.num_new_tokens] = (
                    orig_embeds_params[: -self.num_new_tokens].data
                )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        vision_tower_high = getattr(self, "vision_tower_high", None)
        if (
            vision_tower_high is not None
            and (input_ids.shape[1] != 1 or self.training)
            and images is not None
        ):
            use_im_start_end = getattr(self.config, "use_im_start_end", -1)
            im_patch_token = getattr(self.config, "im_patch_token", -1)
            im_start_token = getattr(self.config, "im_start_token", -1)
            im_end_token = getattr(self.config, "im_end_token", -1)
            # Hard-coded special-token ids override any values read from the config.
            im_patch_token = 151859
            im_start_token = 151857
            im_end_token = 151858
            image_features = []
            for image in images:
                if self.training:
                    image = image[1]
                P, C, H, W = image.shape
                if P == 1:
                    with paddle.set_grad_enabled(False):
                        cnn_feature = vision_tower_high(image)
                        cnn_feature = cnn_feature.flatten(2).transpose(
                            [0, 2, 1]
                        )  # 256*1024
                    image_feature = self.mm_projector_vary(cnn_feature)
                    image_features.append(image_feature)
                else:
                    image_patches = paddle.unbind(image)
                    image_patches_features = []
                    for image_patch in image_patches:
                        image_p = paddle.stack([image_patch])
                        with paddle.set_grad_enabled(False):
                            cnn_feature_p = vision_tower_high(image_p)
                            cnn_feature_p = cnn_feature_p.flatten(2).transpose(
                                [0, 2, 1]
                            )
                        image_feature_p = self.mm_projector_vary(cnn_feature_p)
                        image_patches_features.append(image_feature_p)
                    image_feature = paddle.concat(image_patches_features, axis=1)
                    image_features.append(image_feature)
            dummy_image_features_2 = paddle.zeros(
                [256, 1024], dtype=inputs_embeds.dtype
            )
            dummy_image_features = dummy_image_features_2
            use_im_start_end = True
            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(
                input_ids, inputs_embeds, image_features
            ):
                if (cur_input_ids == im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = (
                        cur_input_embeds + (0.0 * dummy_image_features).sum()
                    )
                    new_input_embeds.append(cur_input_embeds)
                    continue
                if use_im_start_end:
                    if (cur_input_ids == im_start_token).sum() != (
                        cur_input_ids == im_end_token
                    ).sum():
                        raise ValueError(
                            "The number of image start tokens and image end tokens should be the same."
                        )
                    image_start_tokens = paddle.where(cur_input_ids == im_start_token)[
                        0
                    ]
                    for image_start_token_pos, per_cur_image_features in zip(
                        image_start_tokens, cur_image_features
                    ):
                        num_patches = per_cur_image_features.shape[0]
                        if (
                            cur_input_ids[image_start_token_pos + num_patches + 1]
                            != im_end_token
                        ):
                            raise ValueError(
                                "The image end token should follow the image start token."
                            )
                        cur_input_embeds = paddle.concat(
                            (
                                cur_input_embeds[: image_start_token_pos + 1],
                                per_cur_image_features,
                                cur_input_embeds[
                                    image_start_token_pos + num_patches + 1 :
                                ],
                            ),
                            axis=0,
                        )
                    new_input_embeds.append(cur_input_embeds)
                else:
                    raise NotImplementedError
            inputs_embeds = paddle.stack(new_input_embeds, axis=0)
        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
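
# Added summary of the multimodal splicing in GOTQwenModel.forward: for each sample
# whose input_ids contain image tokens, the run of <imgpad> placeholders between
# <img> and </img> is replaced, in embedding space, by the projected vision features
# (256 tokens of width 1024 per 1024x1024 crop), so the sequence length is unchanged.
# Samples without image tokens pass through unchanged apart from a zero-valued dummy
# term.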


class GOTQwenForCausalLM(Qwen2ForCausalLM):
    config_class = GOTConfig

    def __init__(self, config):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.qwen2 = GOTQwenModel(config)
        self.vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            self.lm_head = Qwen2LMHead(
                config,
                embedding_weights=self.qwen2.embed_tokens.weight,
                transpose_y=True,
            )
            self.tie_weights()
        else:
            self.lm_head = Qwen2LMHead(config)
        self.eval()

    def get_model(self):
        return self.qwen2

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        labels: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[paddle.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.qwen2(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.astype(dtype="float32")
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss_fct = nn.CrossEntropyLoss(reduction="sum")
            shift_logits = shift_logits.reshape([-1, self.config.vocab_size])
            shift_labels = shift_labels.reshape([-1])
            loss = loss_fct(shift_logits, shift_labels)
            label_sum = paddle.sum(shift_labels != -100)
            loss = loss / label_sum
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        batch_size, seq_length = input_ids.shape
        attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool)
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[1]
            if past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.astype(dtype="int64").cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs


class PPChart2TableInference(GOTQwenForCausalLM):
    def generate(self, inputs, **kwargs):
        max_new_tokens = kwargs.get("max_new_tokens", 1024)
        no_repeat_ngram_size = kwargs.get("no_repeat_ngram_size", 20)
        with paddle.no_grad():
            generated_ids = super().generate(
                inputs["input_ids"],
                images=inputs["images"],
                do_sample=False,
                num_beams=1,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
            )
        return generated_ids
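
# Illustrative usage sketch (added, not part of the upstream file). The prompt
# construction, image preprocessing, and weight loading are handled by the
# surrounding PaddleOCR/PaddleX pipeline, so the names and tensors below are
# placeholders:
#
#     model = PPChart2TableInference.from_pretrained(model_dir)  # hypothetical path
#     inputs = {
#         "input_ids": input_ids,  # [1, seq_len], prompt containing <img><imgpad>*256</img>
#         "images": [image],       # each entry shaped [P, 3, 1024, 1024]
#     }
#     generated_ids = model.generate(inputs, max_new_tokens=1024, no_repeat_ngram_size=20)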