rec_donut_swin.py

  1. import collections.abc
  2. from collections import OrderedDict
  3. import math
  4. from dataclasses import dataclass
  5. from typing import Optional, Tuple, Union
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. class DonutSwinConfig(object):
  10. model_type = "donut-swin"
  11. attribute_map = {
  12. "num_attention_heads": "num_heads",
  13. "num_hidden_layers": "num_layers",
  14. }
  15. def __init__(
  16. self,
  17. image_size=224,
  18. patch_size=4,
  19. num_channels=3,
  20. embed_dim=96,
  21. depths=[2, 2, 6, 2],
  22. num_heads=[3, 6, 12, 24],
  23. window_size=7,
  24. mlp_ratio=4.0,
  25. qkv_bias=True,
  26. hidden_dropout_prob=0.0,
  27. attention_probs_dropout_prob=0.0,
  28. drop_path_rate=0.1,
  29. hidden_act="gelu",
  30. use_absolute_embeddings=False,
  31. initializer_range=0.02,
  32. layer_norm_eps=1e-5,
  33. **kwargs,
  34. ):
  35. super().__init__()
  36. self.image_size = image_size
  37. self.patch_size = patch_size
  38. self.num_channels = num_channels
  39. self.embed_dim = embed_dim
  40. self.depths = depths
  41. self.num_layers = len(depths)
  42. self.num_heads = num_heads
  43. self.window_size = window_size
  44. self.mlp_ratio = mlp_ratio
  45. self.qkv_bias = qkv_bias
  46. self.hidden_dropout_prob = hidden_dropout_prob
  47. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  48. self.drop_path_rate = drop_path_rate
  49. self.hidden_act = hidden_act
  50. self.use_absolute_embeddings = use_absolute_embeddings
  51. self.layer_norm_eps = layer_norm_eps
  52. self.initializer_range = initializer_range
  53. self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
  54. for key, value in kwargs.items():
  55. try:
  56. setattr(self, key, value)
  57. except AttributeError as err:
  58. print(f"Can't set {key} with value {value} for {self}")
  59. raise err
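# Illustrative note (not part of the original file): the config derives num_layers from
# len(depths) and hidden_size as embed_dim * 2 ** (num_layers - 1). A minimal sketch with
# the default arguments:
#
#     cfg = DonutSwinConfig(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
#     assert cfg.num_layers == 4
#     assert cfg.hidden_size == 96 * 2 ** 3 == 768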
  60. @dataclass
  61. # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
  62. class DonutSwinEncoderOutput(OrderedDict):
  63. last_hidden_state = None
  64. hidden_states = None
  65. attentions = None
  66. reshaped_hidden_states = None
  67. def __init__(self, *args, **kwargs):
  68. super().__init__(*args, **kwargs)
  69. def __getitem__(self, k):
  70. if isinstance(k, str):
  71. inner_dict = dict(self.items())
  72. return inner_dict[k]
  73. else:
  74. return self.to_tuple()[k]
  75. def __setattr__(self, name, value):
  76. if name in self.keys() and value is not None:
  77. super().__setitem__(name, value)
  78. super().__setattr__(name, value)
  79. def __setitem__(self, key, value):
  80. super().__setitem__(key, value)
  81. super().__setattr__(key, value)
  82. def to_tuple(self):
  83. """
  84. Convert self to a tuple containing all the attributes/keys that are not `None`.
  85. """
  86. return tuple(self[k] for k in self.keys())
  87. @dataclass
  88. # Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
  89. class DonutSwinModelOutput(OrderedDict):
  90. last_hidden_state = None
  91. pooler_output = None
  92. hidden_states = None
  93. attentions = None
  94. reshaped_hidden_states = None
  95. def __init__(self, *args, **kwargs):
  96. super().__init__(*args, **kwargs)
  97. def __getitem__(self, k):
  98. if isinstance(k, str):
  99. inner_dict = dict(self.items())
  100. return inner_dict[k]
  101. else:
  102. return self.to_tuple()[k]
  103. def __setattr__(self, name, value):
  104. if name in self.keys() and value is not None:
  105. super().__setitem__(name, value)
  106. super().__setattr__(name, value)
  107. def __setitem__(self, key, value):
  108. super().__setitem__(key, value)
  109. super().__setattr__(key, value)
  110. def to_tuple(self):
  111. """
  112. Convert self to a tuple containing all the attributes/keys that are not `None`.
  113. """
  114. return tuple(self[k] for k in self.keys())
  115. # Copied from transformers.models.swin.modeling_swin.window_partition
  116. def window_partition(input_feature, window_size):
  117. """
  118. Partitions the given input into windows.
  119. """
  120. batch_size, height, width, num_channels = input_feature.shape
  121. input_feature = input_feature.reshape(
  122. [
  123. batch_size,
  124. height // window_size,
  125. window_size,
  126. width // window_size,
  127. window_size,
  128. num_channels,
  129. ]
  130. )
131. windows = input_feature.permute(0, 1, 3, 2, 4, 5).reshape(
  132. [-1, window_size, window_size, num_channels]
  133. )
  134. return windows
  135. # Copied from transformers.models.swin.modeling_swin.window_reverse
  136. def window_reverse(windows, window_size, height, width):
  137. """
  138. Merges windows to produce higher resolution features.
  139. """
  140. num_channels = windows.shape[-1]
  141. windows = windows.reshape(
  142. [
  143. -1,
  144. height // window_size,
  145. width // window_size,
  146. window_size,
  147. window_size,
  148. num_channels,
  149. ]
  150. )
151. windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(
  152. [-1, height, width, num_channels]
  153. )
  154. return windows
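# Shape sketch for the two helpers above (illustrative): window_partition maps a
# (batch, height, width, channels) tensor to (num_windows * batch, window_size, window_size,
# channels), and window_reverse is its inverse for the same height/width. For example:
#
#     x = torch.rand(1, 14, 14, 96)
#     w = window_partition(x, 7)        # -> (4, 7, 7, 96)
#     y = window_reverse(w, 7, 14, 14)  # -> (1, 14, 14, 96), identical to x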
  155. # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
  156. class DonutSwinEmbeddings(nn.Module):
  157. """
  158. Construct the patch and position embeddings. Optionally, also the mask token.
  159. """
  160. def __init__(self, config, use_mask_token=False):
  161. super().__init__()
  162. self.patch_embeddings = DonutSwinPatchEmbeddings(config)
  163. num_patches = self.patch_embeddings.num_patches
  164. self.patch_grid = self.patch_embeddings.grid_size
  165. if use_mask_token:
  166. # self.mask_token = paddle.create_parameter(
  167. # [1, 1, config.embed_dim], dtype="float32"
  168. # )
  169. self.mask_token = nn.Parameter(
  170. nn.init.xavier_uniform_(torch.zeros(1, 1, config.embed_dim).to(torch.float32))
  171. )
  172. nn.init.zeros_(self.mask_token)
  173. else:
  174. self.mask_token = None
  175. if config.use_absolute_embeddings:
  176. # self.position_embeddings = paddle.create_parameter(
  177. # [1, num_patches + 1, config.embed_dim], dtype="float32"
  178. # )
  179. self.position_embeddings = nn.Parameter(
  180. nn.init.xavier_uniform_(torch.zeros(1, num_patches + 1, config.embed_dim).to(torch.float32))
  181. )
182. nn.init.zeros_(self.position_embeddings)
  183. else:
  184. self.position_embeddings = None
  185. self.norm = nn.LayerNorm(config.embed_dim)
  186. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  187. def forward(self, pixel_values, bool_masked_pos=None):
  188. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  189. embeddings = self.norm(embeddings)
  190. batch_size, seq_len, _ = embeddings.shape
  191. if bool_masked_pos is not None:
  192. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  193. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  194. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  195. if self.position_embeddings is not None:
  196. embeddings = embeddings + self.position_embeddings
  197. embeddings = self.dropout(embeddings)
  198. return embeddings, output_dimensions
  199. class MyConv2d(nn.Conv2d):
  200. def __init__(
  201. self,
  202. in_channel,
  203. out_channels,
  204. kernel_size,
  205. stride=1,
206. padding="same",
  207. dilation=1,
  208. groups=1,
  209. bias_attr=False,
  210. eps=1e-6,
  211. ):
  212. super().__init__(
  213. in_channel,
  214. out_channels,
  215. kernel_size,
  216. stride=stride,
  217. padding=padding,
  218. dilation=dilation,
  219. groups=groups,
220. bias=bias_attr,
  221. )
  222. # self.weight = paddle.create_parameter(
  223. # [out_channels, in_channel, kernel_size[0], kernel_size[1]], dtype="float32"
  224. # )
225. self.weight = nn.Parameter(
  226. nn.init.xavier_uniform_(
  227. torch.zeros(out_channels, in_channel, kernel_size[0], kernel_size[1]).to(torch.float32)
  228. )
  229. )
  230. # self.bias = paddle.create_parameter([out_channels], dtype="float32")
231. self.bias = nn.Parameter(
232. # xavier init is undefined for 1-D tensors; the bias is zero-initialized
233. # below anyway, so start it from zeros directly
234. torch.zeros(out_channels).to(torch.float32)
235. )
  236. nn.init.ones_(self.weight)
  237. nn.init.zeros_(self.bias)
  238. def forward(self, x):
  239. x = F.conv2d(
  240. x,
  241. self.weight,
  242. self.bias,
243. self.stride,
244. self.padding,
245. self.dilation,
246. self.groups,
  247. )
  248. return x
  249. # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
  250. class DonutSwinPatchEmbeddings(nn.Module):
  251. """
  252. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  253. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  254. Transformer.
  255. """
  256. def __init__(self, config):
  257. super().__init__()
  258. image_size, patch_size = config.image_size, config.patch_size
  259. num_channels, hidden_size = config.num_channels, config.embed_dim
  260. image_size = (
  261. image_size
  262. if isinstance(image_size, collections.abc.Iterable)
  263. else (image_size, image_size)
  264. )
  265. patch_size = (
  266. patch_size
  267. if isinstance(patch_size, collections.abc.Iterable)
  268. else (patch_size, patch_size)
  269. )
  270. num_patches = (image_size[1] // patch_size[1]) * (
  271. image_size[0] // patch_size[0]
  272. )
  273. self.image_size = image_size
  274. self.patch_size = patch_size
  275. self.num_channels = num_channels
  276. self.num_patches = num_patches
  277. self.is_export = config.is_export
  278. self.grid_size = (
  279. image_size[0] // patch_size[0],
  280. image_size[1] // patch_size[1],
  281. )
282. self.projection = nn.Conv2d(
  283. num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
  284. )
  285. def maybe_pad(self, pixel_values, height, width):
  286. if width % self.patch_size[1] != 0:
  287. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  288. if self.is_export:
  289. pad_values = torch.tensor(pad_values, dtype=torch.int32)
  290. pixel_values = nn.functional.pad(pixel_values, pad_values)
  291. if height % self.patch_size[0] != 0:
  292. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  293. if self.is_export:
  294. pad_values = torch.tensor(pad_values, dtype=torch.int32)
  295. pixel_values = nn.functional.pad(pixel_values, pad_values)
  296. return pixel_values
  297. def forward(self, pixel_values) -> Tuple[torch.Tensor, Tuple[int]]:
  298. _, num_channels, height, width = pixel_values.shape
  299. if num_channels != self.num_channels:
  300. raise ValueError(
  301. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  302. )
  303. pixel_values = self.maybe_pad(pixel_values, height, width)
  304. embeddings = self.projection(pixel_values)
  305. _, _, height, width = embeddings.shape
  306. output_dimensions = (height, width)
307. embeddings = embeddings.flatten(2).transpose(1, 2)
  308. return embeddings, output_dimensions
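# Shape sketch (illustrative): with the default patch_size=4, a (batch, 3, 224, 224) input is
# projected by the strided conv to (batch, embed_dim, 56, 56), then flattened and transposed to
# a (batch, 3136, embed_dim) sequence; output_dimensions is returned as (56, 56).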
  309. # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
  310. class DonutSwinPatchMerging(nn.Module):
  311. """
  312. Patch Merging Layer.
  313. Args:
  314. input_resolution (`Tuple[int]`):
  315. Resolution of input feature.
  316. dim (`int`):
  317. Number of input channels.
  318. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  319. Normalization layer class.
  320. """
  321. def __init__(
  322. self,
  323. input_resolution: Tuple[int],
  324. dim: int,
  325. norm_layer: nn.Module = nn.LayerNorm,
  326. is_export=False,
  327. ):
  328. super().__init__()
  329. self.input_resolution = input_resolution
  330. self.dim = dim
331. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  332. self.norm = norm_layer(4 * dim)
  333. self.is_export = is_export
  334. def maybe_pad(self, input_feature, height, width):
  335. should_pad = (height % 2 == 1) or (width % 2 == 1)
  336. if should_pad:
  337. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  338. if self.is_export:
  339. pad_values = torch.tensor(pad_values, dtype=torch.int32)
  340. input_feature = nn.functional.pad(input_feature, pad_values)
  341. return input_feature
  342. def forward(
  343. self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
  344. ) -> torch.Tensor:
  345. height, width = input_dimensions
  346. batch_size, dim, num_channels = input_feature.shape
  347. input_feature = input_feature.reshape([batch_size, height, width, num_channels])
  348. input_feature = self.maybe_pad(input_feature, height, width)
  349. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  350. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  351. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  352. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  353. input_feature = torch.cat(
  354. [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
  355. )
  356. input_feature = input_feature.reshape(
  357. [batch_size, -1, 4 * num_channels]
  358. ) # batch_size height/2*width/2 4*C
  359. input_feature = self.norm(input_feature)
  360. input_feature = self.reduction(input_feature)
  361. return input_feature
  362. # Copied from transformers.models.beit.modeling_beit.drop_path
  363. def drop_path(
  364. input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
  365. ) -> torch.Tensor:
  366. if drop_prob == 0.0 or not training:
  367. return input
  368. keep_prob = 1 - drop_prob
  369. shape = (input.shape[0],) + (1,) * (
  370. input.ndim - 1
  371. ) # work with diff dim tensors, not just 2D ConvNets
372. random_tensor = keep_prob + torch.rand(
373. shape,
374. dtype=input.dtype, device=input.device,
375. )
  376. random_tensor.floor_() # binarize
  377. output = input / keep_prob * random_tensor
  378. return output
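# Note on the scaling above (illustrative): dividing by keep_prob keeps the expected value of
# the output equal to the input, since E[x / keep_prob * Bernoulli(keep_prob)] = x. With
# drop_prob=0.1, roughly 10% of the samples in a batch have their residual branch zeroed and
# the surviving ones are scaled by 1 / 0.9.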
  379. # Copied from transformers.models.swin.modeling_swin.SwinDropPath
  380. class DonutSwinDropPath(nn.Module):
  381. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  382. def __init__(self, drop_prob: Optional[float] = None) -> None:
  383. super().__init__()
  384. self.drop_prob = drop_prob
  385. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  386. return drop_path(hidden_states, self.drop_prob, self.training)
  387. def extra_repr(self) -> str:
  388. return "p={}".format(self.drop_prob)
  389. class DonutSwinSelfAttention(nn.Module):
  390. def __init__(self, config, dim, num_heads, window_size):
  391. super().__init__()
  392. if dim % num_heads != 0:
  393. raise ValueError(
  394. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  395. )
  396. self.num_attention_heads = num_heads
  397. self.attention_head_size = int(dim / num_heads)
  398. self.all_head_size = self.num_attention_heads * self.attention_head_size
  399. self.window_size = (
  400. window_size
  401. if isinstance(window_size, collections.abc.Iterable)
  402. else (window_size, window_size)
  403. )
  404. # self.relative_position_bias_table = paddle.create_parameter(
  405. # [(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads],
  406. # dtype="float32",
  407. # )
408. self.relative_position_bias_table = nn.Parameter(
  409. nn.init.xavier_normal_(
  410. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads).to(torch.float32)
  411. )
  412. )
  413. nn.init.zeros_(self.relative_position_bias_table)
  414. # get pair-wise relative position index for each token inside the window
  415. coords_h = torch.arange(self.window_size[0])
  416. coords_w = torch.arange(self.window_size[1])
  417. coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
  418. coords_flatten = torch.flatten(coords, 1)
  419. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
420. relative_coords = relative_coords.permute(1, 2, 0)
  421. relative_coords[:, :, 0] += self.window_size[0] - 1
  422. relative_coords[:, :, 1] += self.window_size[1] - 1
  423. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  424. relative_position_index = relative_coords.sum(-1)
  425. self.register_buffer("relative_position_index", relative_position_index)
426. self.query = nn.Linear(
427. self.all_head_size, self.all_head_size, bias=config.qkv_bias
428. )
429. self.key = nn.Linear(
430. self.all_head_size, self.all_head_size, bias=config.qkv_bias
431. )
432. self.value = nn.Linear(
433. self.all_head_size, self.all_head_size, bias=config.qkv_bias
434. )
  435. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  436. def transpose_for_scores(self, x):
437. new_x_shape = x.shape[:-1] + (
438. self.num_attention_heads,
439. self.attention_head_size,
440. )
441. x = x.reshape(new_x_shape)
442. return x.permute(0, 2, 1, 3)
  443. def forward(
  444. self,
  445. hidden_states: torch.Tensor,
  446. attention_mask=None,
  447. head_mask=None,
  448. output_attentions=False,
  449. ) -> Tuple[torch.Tensor]:
  450. batch_size, dim, num_channels = hidden_states.shape
  451. mixed_query_layer = self.query(hidden_states)
  452. key_layer = self.transpose_for_scores(self.key(hidden_states))
  453. value_layer = self.transpose_for_scores(self.value(hidden_states))
  454. query_layer = self.transpose_for_scores(mixed_query_layer)
  455. # Take the dot product between "query" and "key" to get the raw attention scores.
456. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  457. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  458. relative_position_bias = self.relative_position_bias_table[
  459. self.relative_position_index.reshape([-1])
  460. ]
  461. relative_position_bias = relative_position_bias.reshape(
  462. [
  463. self.window_size[0] * self.window_size[1],
  464. self.window_size[0] * self.window_size[1],
  465. -1,
  466. ]
  467. )
468. relative_position_bias = relative_position_bias.permute(2, 0, 1)
  469. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  470. if attention_mask is not None:
  471. # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
  472. mask_shape = attention_mask.shape[0]
  473. attention_scores = attention_scores.reshape(
  474. [
  475. batch_size // mask_shape,
  476. mask_shape,
  477. self.num_attention_heads,
  478. dim,
  479. dim,
  480. ]
  481. )
  482. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(
  483. 0
  484. )
  485. attention_scores = attention_scores.reshape(
  486. [-1, self.num_attention_heads, dim, dim]
  487. )
  488. # Normalize the attention scores to probabilities.
489. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  490. # This is actually dropping out entire tokens to attend to, which might
  491. # seem a bit unusual, but is taken from the original Transformer paper.
  492. attention_probs = self.dropout(attention_probs)
  493. # Mask heads if we want to
  494. if head_mask is not None:
  495. attention_probs = attention_probs * head_mask
  496. context_layer = torch.matmul(attention_probs, value_layer)
497. context_layer = context_layer.permute(0, 2, 1, 3)
  498. new_context_layer_shape = tuple(context_layer.shape[:-2]) + (
  499. self.all_head_size,
  500. )
  501. context_layer = context_layer.reshape(new_context_layer_shape)
  502. outputs = (
  503. (context_layer, attention_probs) if output_attentions else (context_layer,)
  504. )
  505. return outputs
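# How the relative position bias above is indexed (illustrative): every pair of tokens inside a
# (Wh, Ww) window has a relative offset in [-(Wh - 1), Wh - 1] x [-(Ww - 1), Ww - 1]; the offsets
# are shifted to be non-negative and flattened into a single index into the
# ((2 * Wh - 1) * (2 * Ww - 1), num_heads) table, so the learned bias is shared across windows
# and batches. For window_size=(7, 7) the table has 13 * 13 = 169 rows and the bias added to the
# attention logits has shape (num_heads, 49, 49).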
  506. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
  507. class DonutSwinSelfOutput(nn.Module):
  508. def __init__(self, config, dim):
  509. super().__init__()
  510. self.dense = nn.Linear(dim, dim)
  511. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  512. def forward(
  513. self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
  514. ) -> torch.Tensor:
  515. hidden_states = self.dense(hidden_states)
  516. hidden_states = self.dropout(hidden_states)
  517. return hidden_states
  518. # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
  519. class DonutSwinAttention(nn.Module):
  520. def __init__(self, config, dim, num_heads, window_size):
  521. super().__init__()
  522. self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
  523. self.output = DonutSwinSelfOutput(config, dim)
  524. self.pruned_heads = set()
  525. def forward(
  526. self,
  527. hidden_states: torch.Tensor,
  528. attention_mask=None,
  529. head_mask=None,
  530. output_attentions=False,
  531. ) -> Tuple[torch.Tensor]:
  532. self_outputs = self.self(
  533. hidden_states, attention_mask, head_mask, output_attentions
  534. )
  535. attention_output = self.output(self_outputs[0], hidden_states)
  536. outputs = (attention_output,) + self_outputs[
  537. 1:
  538. ] # add attentions if we output them
  539. return outputs
  540. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate
  541. class DonutSwinIntermediate(nn.Module):
  542. def __init__(self, config, dim):
  543. super().__init__()
  544. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  545. self.intermediate_act_fn = F.gelu
  546. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  547. hidden_states = self.dense(hidden_states)
  548. hidden_states = self.intermediate_act_fn(hidden_states)
  549. return hidden_states
  550. # Copied from transformers.models.swin.modeling_swin.SwinOutput
  551. class DonutSwinOutput(nn.Module):
  552. def __init__(self, config, dim):
  553. super().__init__()
  554. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  555. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  556. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  557. hidden_states = self.dense(hidden_states)
  558. hidden_states = self.dropout(hidden_states)
  559. return hidden_states
  560. # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
  561. class DonutSwinLayer(nn.Module):
  562. def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
  563. super().__init__()
  564. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  565. self.shift_size = shift_size
  566. self.window_size = config.window_size
  567. self.input_resolution = input_resolution
  568. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  569. self.attention = DonutSwinAttention(
  570. config, dim, num_heads, window_size=self.window_size
  571. )
  572. self.drop_path = (
  573. DonutSwinDropPath(config.drop_path_rate)
  574. if config.drop_path_rate > 0.0
  575. else nn.Identity()
  576. )
  577. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  578. self.intermediate = DonutSwinIntermediate(config, dim)
  579. self.output = DonutSwinOutput(config, dim)
  580. self.is_export = config.is_export
  581. def set_shift_and_window_size(self, input_resolution):
  582. if min(input_resolution) <= self.window_size:
  583. # if window size is larger than input resolution, we don't partition windows
  584. self.shift_size = 0
  585. self.window_size = min(input_resolution)
  586. def get_attn_mask_export(self, height, width, dtype):
  587. attn_mask = None
  588. height_slices = (
  589. slice(0, -self.window_size),
  590. slice(-self.window_size, -self.shift_size),
  591. slice(-self.shift_size, None),
  592. )
  593. width_slices = (
  594. slice(0, -self.window_size),
  595. slice(-self.window_size, -self.shift_size),
  596. slice(-self.shift_size, None),
  597. )
  598. img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
  599. count = 0
  600. for height_slice in height_slices:
  601. for width_slice in width_slices:
  602. if self.shift_size > 0:
  603. img_mask[:, height_slice, width_slice, :] = count
  604. count += 1
605. if torch.tensor(self.shift_size > 0).to(torch.bool):
  606. # calculate attention mask for SW-MSA
  607. mask_windows = window_partition(img_mask, self.window_size)
  608. mask_windows = mask_windows.reshape(
  609. [-1, self.window_size * self.window_size]
  610. )
  611. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  612. attn_mask = attn_mask.masked_fill(
  613. attn_mask != 0, float(-100.0)
  614. ).masked_fill(attn_mask == 0, float(0.0))
  615. return attn_mask
  616. def get_attn_mask(self, height, width, dtype):
  617. if self.shift_size > 0:
  618. # calculate attention mask for SW-MSA
  619. img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
  620. height_slices = (
  621. slice(0, -self.window_size),
  622. slice(-self.window_size, -self.shift_size),
  623. slice(-self.shift_size, None),
  624. )
  625. width_slices = (
  626. slice(0, -self.window_size),
  627. slice(-self.window_size, -self.shift_size),
  628. slice(-self.shift_size, None),
  629. )
  630. count = 0
  631. for height_slice in height_slices:
  632. for width_slice in width_slices:
  633. img_mask[:, height_slice, width_slice, :] = count
  634. count += 1
  635. mask_windows = window_partition(img_mask, self.window_size)
  636. mask_windows = mask_windows.reshape(
  637. [-1, self.window_size * self.window_size]
  638. )
  639. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  640. attn_mask = attn_mask.masked_fill(
  641. attn_mask != 0, float(-100.0)
  642. ).masked_fill(attn_mask == 0, float(0.0))
  643. else:
  644. attn_mask = None
  645. return attn_mask
  646. def maybe_pad(self, hidden_states, height, width):
  647. pad_right = (self.window_size - width % self.window_size) % self.window_size
  648. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  649. pad_values = (0, 0, 0, pad_bottom, 0, pad_right, 0, 0)
  650. hidden_states = nn.functional.pad(hidden_states, pad_values)
  651. return hidden_states, pad_values
  652. def forward(
  653. self,
  654. hidden_states: torch.Tensor,
  655. input_dimensions: Tuple[int, int],
  656. head_mask=None,
  657. output_attentions=False,
  658. always_partition=False,
  659. ) -> Tuple[torch.Tensor, torch.Tensor]:
  660. if not always_partition:
  661. self.set_shift_and_window_size(input_dimensions)
  662. else:
  663. pass
  664. height, width = input_dimensions
  665. batch_size, _, channels = hidden_states.shape
  666. shortcut = hidden_states
  667. hidden_states = self.layernorm_before(hidden_states)
  668. hidden_states = hidden_states.reshape([batch_size, height, width, channels])
  669. # pad hidden_states to multiples of window size
  670. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  671. _, height_pad, width_pad, _ = hidden_states.shape
  672. # cyclic shift
  673. if self.shift_size > 0:
  674. shift_value = (-self.shift_size, -self.shift_size)
  675. if self.is_export:
  676. shift_value = torch.tensor(shift_value, dtype=torch.int32)
  677. shifted_hidden_states = torch.roll(
  678. hidden_states, shifts=shift_value, dims=(1, 2)
  679. )
  680. else:
  681. shifted_hidden_states = hidden_states
  682. # partition windows
  683. hidden_states_windows = window_partition(
  684. shifted_hidden_states, self.window_size
  685. )
  686. hidden_states_windows = hidden_states_windows.reshape(
  687. [-1, self.window_size * self.window_size, channels]
  688. )
689. attn_mask = (self.get_attn_mask_export if self.is_export else self.get_attn_mask)(height_pad, width_pad, dtype=hidden_states.dtype)
  690. attention_outputs = self.attention(
  691. hidden_states_windows,
  692. attn_mask,
  693. head_mask,
  694. output_attentions=output_attentions,
  695. )
  696. attention_output = attention_outputs[0]
  697. attention_windows = attention_output.reshape(
  698. [-1, self.window_size, self.window_size, channels]
  699. )
  700. shifted_windows = window_reverse(
  701. attention_windows, self.window_size, height_pad, width_pad
  702. )
  703. # reverse cyclic shift
  704. if self.shift_size > 0:
  705. shift_value = (self.shift_size, self.shift_size)
  706. if self.is_export:
  707. shift_value = torch.tensor(shift_value, dtype=torch.int32)
  708. attention_windows = torch.roll(
  709. shifted_windows, shifts=shift_value, dims=(1, 2)
  710. )
  711. else:
  712. attention_windows = shifted_windows
  713. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  714. if was_padded:
  715. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  716. attention_windows = attention_windows.reshape(
  717. [batch_size, height * width, channels]
  718. )
  719. hidden_states = shortcut + self.drop_path(attention_windows)
  720. layer_output = self.layernorm_after(hidden_states)
  721. layer_output = self.intermediate(layer_output)
  722. layer_output = hidden_states + self.output(layer_output)
  723. layer_outputs = (
  724. (layer_output, attention_outputs[1])
  725. if output_attentions
  726. else (layer_output,)
  727. )
  728. return layer_outputs
  729. # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
  730. class DonutSwinStage(nn.Module):
  731. def __init__(
  732. self, config, dim, input_resolution, depth, num_heads, drop_path, downsample
  733. ):
  734. super().__init__()
  735. self.config = config
  736. self.dim = dim
  737. self.blocks = nn.ModuleList(
  738. [
  739. DonutSwinLayer(
  740. config=config,
  741. dim=dim,
  742. input_resolution=input_resolution,
  743. num_heads=num_heads,
  744. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  745. )
  746. for i in range(depth)
  747. ]
  748. )
  749. self.is_export = config.is_export
  750. # patch merging layer
  751. if downsample is not None:
  752. self.downsample = downsample(
  753. input_resolution,
  754. dim=dim,
  755. norm_layer=nn.LayerNorm,
  756. is_export=self.is_export,
  757. )
  758. else:
  759. self.downsample = None
  760. self.pointing = False
  761. def forward(
  762. self,
  763. hidden_states: torch.Tensor,
  764. input_dimensions: Tuple[int, int],
  765. head_mask=None,
  766. output_attentions=False,
  767. always_partition=False,
  768. ) -> Tuple[torch.Tensor]:
  769. height, width = input_dimensions
  770. for i, layer_module in enumerate(self.blocks):
  771. layer_head_mask = head_mask[i] if head_mask is not None else None
  772. layer_outputs = layer_module(
  773. hidden_states,
  774. input_dimensions,
  775. layer_head_mask,
  776. output_attentions,
  777. always_partition,
  778. )
  779. hidden_states = layer_outputs[0]
  780. hidden_states_before_downsampling = hidden_states
  781. if self.downsample is not None:
  782. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  783. output_dimensions = (height, width, height_downsampled, width_downsampled)
  784. hidden_states = self.downsample(
  785. hidden_states_before_downsampling, input_dimensions
  786. )
  787. else:
  788. output_dimensions = (height, width, height, width)
  789. stage_outputs = (
  790. hidden_states,
  791. hidden_states_before_downsampling,
  792. output_dimensions,
  793. )
  794. if output_attentions:
  795. stage_outputs += layer_outputs[1:]
  796. return stage_outputs
  797. # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
  798. class DonutSwinEncoder(nn.Module):
  799. def __init__(self, config, grid_size):
  800. super().__init__()
  801. self.num_layers = len(config.depths)
  802. self.config = config
  803. dpr = [
  804. x.item()
  805. for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
  806. ]
  807. self.layers = nn.ModuleList(
  808. [
  809. DonutSwinStage(
  810. config=config,
  811. dim=int(config.embed_dim * 2**i_layer),
  812. input_resolution=(
  813. grid_size[0] // (2**i_layer),
  814. grid_size[1] // (2**i_layer),
  815. ),
  816. depth=config.depths[i_layer],
  817. num_heads=config.num_heads[i_layer],
  818. drop_path=dpr[
  819. sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
  820. ],
  821. downsample=(
  822. DonutSwinPatchMerging
  823. if (i_layer < self.num_layers - 1)
  824. else None
  825. ),
  826. )
  827. for i_layer in range(self.num_layers)
  828. ]
  829. )
  830. self.gradient_checkpointing = False
  831. def forward(
  832. self,
  833. hidden_states: torch.Tensor,
  834. input_dimensions: Tuple[int, int],
  835. head_mask=None,
  836. output_attentions=False,
  837. output_hidden_states=False,
  838. output_hidden_states_before_downsampling=False,
  839. always_partition=False,
  840. return_dict=True,
  841. ):
  842. all_hidden_states = () if output_hidden_states else None
  843. all_reshaped_hidden_states = () if output_hidden_states else None
  844. all_self_attentions = () if output_attentions else None
  845. if output_hidden_states:
  846. batch_size, _, hidden_size = hidden_states.shape
  847. reshaped_hidden_state = hidden_states.view(
  848. batch_size, *input_dimensions, hidden_size
  849. )
  850. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  851. all_hidden_states += (hidden_states,)
  852. all_reshaped_hidden_states += (reshaped_hidden_state,)
  853. for i, layer_module in enumerate(self.layers):
  854. layer_head_mask = head_mask[i] if head_mask is not None else None
  855. if self.gradient_checkpointing and self.training:
  856. layer_outputs = self._gradient_checkpointing_func(
  857. layer_module.__call__,
  858. hidden_states,
  859. input_dimensions,
  860. layer_head_mask,
  861. output_attentions,
  862. always_partition,
  863. )
  864. else:
  865. layer_outputs = layer_module(
  866. hidden_states,
  867. input_dimensions,
  868. layer_head_mask,
  869. output_attentions,
  870. always_partition,
  871. )
  872. hidden_states = layer_outputs[0]
  873. hidden_states_before_downsampling = layer_outputs[1]
  874. output_dimensions = layer_outputs[2]
  875. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  876. if output_hidden_states and output_hidden_states_before_downsampling:
  877. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  878. reshaped_hidden_state = hidden_states_before_downsampling.reshape(
  879. [
  880. batch_size,
  881. *(output_dimensions[0], output_dimensions[1]),
  882. hidden_size,
  883. ]
  884. )
885. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  886. all_hidden_states += (hidden_states_before_downsampling,)
  887. all_reshaped_hidden_states += (reshaped_hidden_state,)
  888. elif output_hidden_states and not output_hidden_states_before_downsampling:
  889. batch_size, _, hidden_size = hidden_states.shape
  890. reshaped_hidden_state = hidden_states.reshape(
  891. [batch_size, *input_dimensions, hidden_size]
  892. )
893. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  894. all_hidden_states += (hidden_states,)
  895. all_reshaped_hidden_states += (reshaped_hidden_state,)
  896. if output_attentions:
  897. all_self_attentions += layer_outputs[3:]
  898. if not return_dict:
  899. return tuple(
  900. v
  901. for v in [hidden_states, all_hidden_states, all_self_attentions]
  902. if v is not None
  903. )
  904. return DonutSwinEncoderOutput(
  905. last_hidden_state=hidden_states,
  906. hidden_states=all_hidden_states,
  907. attentions=all_self_attentions,
  908. reshaped_hidden_states=all_reshaped_hidden_states,
  909. )
  910. class DonutSwinPreTrainedModel(nn.Module):
  911. """
  912. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  913. models.
  914. """
  915. config_class = DonutSwinConfig
  916. base_model_prefix = "swin"
  917. main_input_name = "pixel_values"
  918. supports_gradient_checkpointing = True
  919. def _init_weights(self, module):
  920. """Initialize the weights"""
921. if isinstance(module, (nn.Linear, nn.Conv2d)):
  922. # normal_ = Normal(mean=0.0, std=self.config.initializer_range)
  923. nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  924. if module.bias is not None:
  925. nn.init.zeros_(module.bias)
  926. elif isinstance(module, nn.LayerNorm):
  927. nn.init.zeros_(module.bias)
  928. nn.init.ones_(module.weight)
  929. def _initialize_weights(self, module):
  930. """
  931. Initialize the weights if they are not already initialized.
  932. """
  933. if getattr(module, "_is_hf_initialized", False):
  934. return
  935. self._init_weights(module)
  936. def post_init(self):
  937. self.apply(self._initialize_weights)
  938. def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
  939. if head_mask is not None:
  940. head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
  941. if is_attention_chunked is True:
  942. head_mask = head_mask.unsqueeze(-1)
  943. else:
  944. head_mask = [None] * num_hidden_layers
  945. return head_mask
  946. class DonutSwinModel(DonutSwinPreTrainedModel):
  947. def __init__(
  948. self,
  949. in_channels=3,
  950. hidden_size=1024,
  951. num_layers=4,
  952. num_heads=[4, 8, 16, 32],
  953. add_pooling_layer=True,
  954. use_mask_token=False,
  955. is_export=False,
  956. ):
  957. super().__init__()
  958. donut_swin_config = {
  959. "return_dict": True,
  960. "output_hidden_states": False,
  961. "output_attentions": False,
  962. "use_bfloat16": False,
  963. "tf_legacy_loss": False,
  964. "pruned_heads": {},
  965. "tie_word_embeddings": True,
  966. "chunk_size_feed_forward": 0,
  967. "is_encoder_decoder": False,
  968. "is_decoder": False,
  969. "cross_attention_hidden_size": None,
  970. "add_cross_attention": False,
  971. "tie_encoder_decoder": False,
  972. "max_length": 20,
  973. "min_length": 0,
  974. "do_sample": False,
  975. "early_stopping": False,
  976. "num_beams": 1,
  977. "num_beam_groups": 1,
  978. "diversity_penalty": 0.0,
  979. "temperature": 1.0,
  980. "top_k": 50,
  981. "top_p": 1.0,
  982. "typical_p": 1.0,
  983. "repetition_penalty": 1.0,
  984. "length_penalty": 1.0,
  985. "no_repeat_ngram_size": 0,
  986. "encoder_no_repeat_ngram_size": 0,
  987. "bad_words_ids": None,
  988. "num_return_sequences": 1,
  989. "output_scores": False,
  990. "return_dict_in_generate": False,
  991. "forced_bos_token_id": None,
  992. "forced_eos_token_id": None,
  993. "remove_invalid_values": False,
  994. "exponential_decay_length_penalty": None,
  995. "suppress_tokens": None,
  996. "begin_suppress_tokens": None,
  997. "architectures": None,
  998. "finetuning_task": None,
  999. "id2label": {0: "LABEL_0", 1: "LABEL_1"},
  1000. "label2id": {"LABEL_0": 0, "LABEL_1": 1},
  1001. "tokenizer_class": None,
  1002. "prefix": None,
  1003. "bos_token_id": None,
  1004. "pad_token_id": None,
  1005. "eos_token_id": None,
  1006. "sep_token_id": None,
  1007. "decoder_start_token_id": None,
  1008. "task_specific_params": None,
  1009. "problem_type": None,
  1010. "_name_or_path": "",
  1011. "_commit_hash": None,
  1012. "_attn_implementation_internal": None,
  1013. "transformers_version": None,
  1014. "hidden_size": hidden_size,
  1015. "num_layers": num_layers,
  1016. "path_norm": True,
  1017. "use_2d_embeddings": False,
  1018. "image_size": [420, 420],
  1019. "patch_size": 4,
  1020. "num_channels": in_channels,
  1021. "embed_dim": 128,
  1022. "depths": [2, 2, 14, 2],
  1023. "num_heads": num_heads,
  1024. "window_size": 5,
  1025. "mlp_ratio": 4.0,
  1026. "qkv_bias": True,
  1027. "hidden_dropout_prob": 0.0,
  1028. "attention_probs_dropout_prob": 0.0,
  1029. "drop_path_rate": 0.1,
  1030. "hidden_act": "gelu",
  1031. "use_absolute_embeddings": False,
  1032. "layer_norm_eps": 1e-05,
  1033. "initializer_range": 0.02,
  1034. "is_export": is_export,
  1035. }
  1036. config = DonutSwinConfig(**donut_swin_config)
  1037. self.config = config
  1038. self.num_layers = len(config.depths)
  1039. self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
  1040. self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
  1041. self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
1042. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  1043. self.out_channels = hidden_size
  1044. self.post_init()
  1045. def get_input_embeddings(self):
  1046. return self.embeddings.patch_embeddings
  1047. def forward(
  1048. self,
  1049. input_data=None,
  1050. bool_masked_pos=None,
  1051. head_mask=None,
  1052. output_attentions=None,
  1053. output_hidden_states=None,
  1054. return_dict=None,
  1055. ) -> Union[Tuple, DonutSwinModelOutput]:
  1056. r"""
1057. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  1058. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  1059. """
  1060. if self.training:
  1061. pixel_values, label, attention_mask = input_data
  1062. else:
  1063. if isinstance(input_data, list):
  1064. pixel_values = input_data[0]
  1065. else:
  1066. pixel_values = input_data
  1067. output_attentions = (
  1068. output_attentions
  1069. if output_attentions is not None
  1070. else self.config.output_attentions
  1071. )
  1072. output_hidden_states = (
  1073. output_hidden_states
  1074. if output_hidden_states is not None
  1075. else self.config.output_hidden_states
  1076. )
  1077. return_dict = (
  1078. return_dict if return_dict is not None else self.config.return_dict
  1079. )
  1080. if pixel_values is None:
  1081. raise ValueError("You have to specify pixel_values")
  1082. num_channels = pixel_values.shape[1]
  1083. if num_channels == 1:
  1084. pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
  1085. head_mask = self.get_head_mask(head_mask, len(self.config.depths))
  1086. embedding_output, input_dimensions = self.embeddings(
  1087. pixel_values, bool_masked_pos=bool_masked_pos
  1088. )
  1089. encoder_outputs = self.encoder(
  1090. embedding_output,
  1091. input_dimensions,
  1092. head_mask=head_mask,
  1093. output_attentions=output_attentions,
  1094. output_hidden_states=output_hidden_states,
  1095. return_dict=return_dict,
  1096. )
  1097. sequence_output = encoder_outputs[0]
  1098. pooled_output = None
  1099. if self.pooler is not None:
1100. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  1101. pooled_output = torch.flatten(pooled_output, 1)
  1102. if not return_dict:
  1103. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  1104. return output
  1105. donut_swin_output = DonutSwinModelOutput(
  1106. last_hidden_state=sequence_output,
  1107. pooler_output=pooled_output,
  1108. hidden_states=encoder_outputs.hidden_states,
  1109. attentions=encoder_outputs.attentions,
  1110. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  1111. )
  1112. if self.training:
  1113. return donut_swin_output, label, attention_mask
  1114. else:
  1115. return donut_swin_output
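if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). It assumes the hard-coded
    # constructor defaults above (embed_dim=128, depths=[2, 2, 14, 2], window_size=5), which
    # give a final feature dimension of 128 * 2 ** 3 = 1024.
    model = DonutSwinModel(in_channels=3, hidden_size=1024, add_pooling_layer=True)
    model.eval()
    # A single-channel page image; forward() tiles it to 3 channels before patch embedding.
    dummy = torch.rand(1, 1, 448, 448)
    with torch.no_grad():
        out = model(dummy)
    # e.g. torch.Size([1, 196, 1024]) and torch.Size([1, 1024]) for a 448 x 448 input
    print(out["last_hidden_state"].shape, out["pooler_output"].shape)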