rec_unimernet_head.py

  1. import copy
  2. import math
  3. import re
  4. import numpy as np
  5. import inspect
  6. import warnings
  7. from collections import OrderedDict
  8. from typing import Optional, Tuple, Union, List, Dict, Any
  9. from dataclasses import dataclass, fields, is_dataclass
  10. import torch
  11. import torch.nn as nn
  12. from torch import Tensor
  13. import torch.nn.functional as F
  14. from torch.nn import CrossEntropyLoss
  15. class ModelOutput(OrderedDict):
  16. def __init__(self, *args, **kwargs):
  17. super().__init__(*args, **kwargs)
  18. def __post_init__(self):
  19. class_fields = fields(self)
  20. if not len(class_fields):
  21. raise ValueError(f"{self.__class__.__name__} has no fields.")
  22. if not all(field.default is None for field in class_fields[1:]):
  23. raise ValueError(
  24. f"{self.__class__.__name__} should not have more than one required field."
  25. )
  26. first_field = getattr(self, class_fields[0].name)
  27. other_fields_are_none = all(
  28. getattr(self, field.name) is None for field in class_fields[1:]
  29. )
  30. if other_fields_are_none:
  31. if isinstance(first_field, dict):
  32. iterator = first_field.items()
  33. first_field_iterator = True
  34. else:
  35. try:
  36. iterator = iter(first_field)
  37. first_field_iterator = True
  38. except TypeError:
  39. first_field_iterator = False
  40. if first_field_iterator:
  41. for idx, element in enumerate(iterator):
  42. if (
  43. not isinstance(element, (list, tuple))
  44. or not len(element) == 2
  45. or not isinstance(element[0], str)
  46. ):
  47. if idx == 0:
  48. self[class_fields[0].name] = first_field
  49. else:
  50. raise ValueError(
  51. f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
  52. )
  53. break
  54. setattr(self, element[0], element[1])
  55. if element[1] is not None:
  56. self[element[0]] = element[1]
  57. elif first_field is not None:
  58. self[class_fields[0].name] = first_field
  59. else:
  60. for field in class_fields:
  61. v = getattr(self, field.name)
  62. if v is not None:
  63. self[field.name] = v
  64. def __delitem__(self, *args, **kwargs):
  65. raise Exception(
  66. f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
  67. )
  68. def setdefault(self, *args, **kwargs):
  69. raise Exception(
  70. f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
  71. )
  72. def pop(self, *args, **kwargs):
  73. raise Exception(
  74. f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
  75. )
  76. def update(self, *args, **kwargs):
  77. raise Exception(
  78. f"You cannot use ``update`` on a {self.__class__.__name__} instance."
  79. )
  80. def __getitem__(self, k):
  81. if isinstance(k, str):
  82. inner_dict = dict(self.items())
  83. return inner_dict[k]
  84. else:
  85. return self.to_tuple()[k]
  86. def __setattr__(self, name, value):
  87. if name in self.keys() and value is not None:
  88. super().__setitem__(name, value)
  89. super().__setattr__(name, value)
  90. def __setitem__(self, key, value):
  91. super().__setitem__(key, value)
  92. super().__setattr__(key, value)
  93. def __reduce__(self):
  94. if not is_dataclass(self):
  95. return super().__reduce__()
  96. callable, _args, *remaining = super().__reduce__()
  97. args = tuple(getattr(self, field.name) for field in fields(self))
  98. return callable, args, *remaining
  99. def to_tuple(self):
  100. return tuple(self[k] for k in self.keys())
  101. @dataclass
  102. class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
  103. last_hidden_state = None
  104. past_key_values = None
  105. hidden_states = None
  106. attentions = None
  107. cross_attentions = None
  108. def __init__(self, *args, **kwargs):
  109. super().__init__(*args, **kwargs)
  110. @dataclass
  111. class Seq2SeqLMOutput(ModelOutput):
  112. loss = None
  113. logits = None
  114. past_key_values = None
  115. decoder_hidden_states = None
  116. decoder_attentions = None
  117. cross_attentions = None
  118. encoder_last_hidden_state = None
  119. encoder_hidden_states = None
  120. encoder_attentions = None
  121. def __init__(self, *args, **kwargs):
  122. super().__init__(*args, **kwargs)
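These output containers behave both like tuples and like attribute bags; a quick illustrative sketch (not part of the original file):

# Illustrative sketch: keyword arguments become both dict items and attributes.
o = BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=torch.ones([1, 2, 4]))
o.last_hidden_state.shape      # torch.Size([1, 2, 4])
o[0].shape                     # the same tensor via tuple-style indexing
o.past_key_values              # None (unset fields fall back to the class defaults)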
  123. class MBartConfig(object):
  124. model_type = "mbart"
  125. keys_to_ignore_at_inference = ["past_key_values"]
  126. attribute_map = {
  127. "num_attention_heads": "encoder_attention_heads",
  128. "hidden_size": "d_model",
  129. }
  130. def __init__(
  131. self,
  132. vocab_size=50265,
  133. max_position_embeddings=1024,
  134. encoder_layers=12,
  135. encoder_ffn_dim=4096,
  136. encoder_attention_heads=16,
  137. decoder_layers=12,
  138. decoder_ffn_dim=4096,
  139. decoder_attention_heads=16,
  140. encoder_layerdrop=0.0,
  141. decoder_layerdrop=0.0,
  142. use_cache=True,
  143. is_encoder_decoder=True,
  144. activation_function="gelu",
  145. d_model=1024,
  146. dropout=0.1,
  147. output_hidden_states=False,
  148. use_return_dict=True,
  149. attention_dropout=0.0,
  150. activation_dropout=0.0,
  151. init_std=0.02,
  152. classifier_dropout=0.0,
  153. scale_embedding=False,
  154. pad_token_id=1,
  155. bos_token_id=0,
  156. eos_token_id=2,
  157. forced_eos_token_id=2,
  158. _attn_implementation="eager",
  159. hidden_size=1024,
  160. use_parallel=False,
  161. parallel_step=2,
  162. is_export=False,
  163. **kwargs,
  164. ):
  165. self.vocab_size = vocab_size
  166. self.hidden_size = hidden_size
  167. self.max_position_embeddings = max_position_embeddings
  168. self.d_model = d_model
  169. self.encoder_ffn_dim = encoder_ffn_dim
  170. self.encoder_layers = encoder_layers
  171. self.encoder_attention_heads = encoder_attention_heads
  172. self.decoder_ffn_dim = decoder_ffn_dim
  173. self.decoder_layers = decoder_layers
  174. self.decoder_attention_heads = decoder_attention_heads
  175. self.dropout = dropout
  176. self.output_hidden_states = output_hidden_states
  177. self.use_return_dict = use_return_dict
  178. self.attention_dropout = attention_dropout
  179. self.activation_dropout = activation_dropout
  180. self.activation_function = activation_function
  181. self.init_std = init_std
  182. self.encoder_layerdrop = encoder_layerdrop
  183. self.decoder_layerdrop = decoder_layerdrop
  184. self.classifier_dropout = classifier_dropout
  185. self.use_cache = use_cache
  186. self.num_hidden_layers = encoder_layers
  187. self.scale_embedding = (
  188. scale_embedding # scale factor will be sqrt(d_model) if True
  189. )
  190. self.pad_token_id = pad_token_id
  191. self.bos_token_id = bos_token_id
  192. self.eos_token_id = eos_token_id
  193. self.is_encoder_decoder = is_encoder_decoder
  194. self.forced_eos_token_id = forced_eos_token_id
  195. self._attn_implementation = _attn_implementation
  196. self.use_parallel = use_parallel
  197. self.parallel_step = parallel_step
  198. self.is_export = is_export
  199. super().__init__()
  200. @dataclass
  201. class AttentionMaskConverter:
  202. """
  203. A utility class for converting attention masks used in transformer models.
  204. This class handles the conversion of attention masks based on whether the
  205. attention mechanism is causal (i.e., preventing information flow from future
  206. tokens to past tokens) and whether a sliding window approach is used.
  207. Attributes:
  208. is_causal (bool): Indicates if the attention mechanism is causal.
  209. sliding_window (Optional[int]): Specifies the size of the sliding window
  210. for local attention, if applicable.
  211. Args:
  212. is_causal (bool): Determines if the attention mask should enforce causality.
  213. sliding_window (Optional[int], optional): The size of the sliding window
  214. for local attention. Default is None.
  215. """
  216. is_causal: bool
217. sliding_window: Optional[int] = None
  218. def __init__(self, is_causal: bool, sliding_window=None):
  219. self.is_causal = is_causal
  220. self.sliding_window = sliding_window
  221. if self.sliding_window is not None and self.sliding_window <= 0:
  222. raise ValueError(
  223. f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
  224. )
  225. @staticmethod
  226. def _make_causal_mask(
  227. input_ids_shape,
  228. dtype,
  229. past_key_values_length=0,
  230. sliding_window=None,
  231. is_export=False,
  232. ):
  233. bsz, tgt_len = input_ids_shape
  234. if is_export:
  235. mask = torch.full(
  236. [tgt_len, tgt_len], fill_value=torch.finfo(dtype).min, dtype=torch.float64
  237. )
  238. else:
  239. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
  240. mask_cond = torch.arange(mask.shape[-1])
  241. mask = mask.masked_fill_(
  242. mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
  243. )
  244. return mask[None, None, :, :].expand(
  245. [bsz, 1, tgt_len, tgt_len + past_key_values_length]
  246. )
  247. def to_4d_export(
  248. self,
  249. attention_mask_2d,
  250. query_length,
  251. dtype,
  252. key_value_length,
  253. is_export=False,
  254. ):
  255. input_shape = (attention_mask_2d.shape[0], query_length)
  256. expanded_attn_mask = self._expand_mask(
  257. attention_mask_2d, dtype, tgt_len=input_shape[-1]
  258. )
  259. expanded_4d_mask = expanded_attn_mask
  260. return expanded_4d_mask
  261. def to_4d(
  262. self,
  263. attention_mask_2d,
  264. query_length,
  265. dtype,
  266. key_value_length,
  267. is_export=False,
  268. ):
  269. input_shape = (attention_mask_2d.shape[0], query_length)
  270. causal_4d_mask = None
  271. if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
  272. if key_value_length is None:
  273. raise ValueError(
  274. "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
  275. )
  276. past_key_values_length = key_value_length - query_length
  277. causal_4d_mask = self._make_causal_mask(
  278. input_shape,
  279. dtype,
  280. past_key_values_length=past_key_values_length,
  281. sliding_window=self.sliding_window,
  282. is_export=is_export,
  283. )
  284. elif self.sliding_window is not None:
  285. raise NotImplementedError(
  286. "Sliding window is currently only implemented for causal masking"
  287. )
  288. expanded_attn_mask = self._expand_mask(
  289. attention_mask_2d, dtype, tgt_len=input_shape[-1]
  290. )
  291. if causal_4d_mask is not None:
  292. if is_export:
  293. expanded_attn_mask = causal_4d_mask
  294. return expanded_attn_mask
  295. else:
296. expanded_attn_mask = causal_4d_mask.masked_fill(
297. expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
298. )
  299. expanded_4d_mask = expanded_attn_mask
  300. return expanded_4d_mask
  301. def _expand_mask(self, mask, dtype, tgt_len=None):
  302. bsz, src_len = mask.shape
  303. tgt_len = tgt_len if tgt_len is not None else src_len
  304. expanded_mask = (
  305. mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
  306. )
  307. inverted_mask = 1.0 - expanded_mask
  308. return inverted_mask.masked_fill_(
  309. inverted_mask.to(torch.bool), torch.finfo(dtype).min
  310. )
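For reference, a minimal sketch of how the converter is typically driven (illustrative only, not part of the original file): a 2-D padding mask is expanded into the 4-D additive mask consumed by the attention layers.

# Illustrative sketch: batch of 2 sequences of length 5, the second padded at the end.
pad_mask = torch.ones([2, 5])
pad_mask[1, 3:] = 0
converter = AttentionMaskConverter(is_causal=True)
mask_4d = converter.to_4d(pad_mask, query_length=5, dtype=torch.float32, key_value_length=5)
# mask_4d: [2, 1, 5, 5]; 0 where attention is allowed,
# torch.finfo(torch.float32).min at future and padded key positions.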
  311. def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
312. return AttentionMaskConverter(is_causal=False)._expand_mask(mask, dtype, tgt_len=tgt_len)
  313. def _prepare_4d_causal_attention_mask_export(
  314. attention_mask,
  315. input_shape,
  316. inputs_embeds,
  317. past_key_values_length,
  318. sliding_window=None,
  319. is_export=False,
  320. ):
  321. attn_mask_converter = AttentionMaskConverter(
  322. is_causal=True, sliding_window=sliding_window
  323. )
  324. key_value_length = input_shape[-1] + past_key_values_length
  325. shape = attention_mask.shape
  326. len_shape = len(shape)
  327. attention_mask = attn_mask_converter.to_4d_export(
  328. attention_mask,
  329. input_shape[-1],
  330. key_value_length=key_value_length,
  331. dtype=inputs_embeds.dtype,
  332. is_export=is_export,
  333. )
  334. return attention_mask
  335. def _prepare_4d_causal_attention_mask(
  336. attention_mask,
  337. input_shape,
  338. inputs_embeds,
  339. past_key_values_length,
  340. sliding_window=None,
  341. is_export=False,
  342. ):
  343. attn_mask_converter = AttentionMaskConverter(
  344. is_causal=True, sliding_window=sliding_window
  345. )
  346. key_value_length = input_shape[-1] + past_key_values_length
  347. shape = attention_mask.shape
  348. len_shape = len(shape)
  349. if (attention_mask is not None) and (len_shape == 2):
  350. attention_mask = attn_mask_converter.to_4d(
  351. attention_mask,
  352. input_shape[-1],
  353. key_value_length=key_value_length,
  354. dtype=inputs_embeds.dtype,
  355. is_export=is_export,
  356. )
  357. return attention_mask
  358. elif attention_mask is not None and len(attention_mask.shape) == 4:
  359. expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
  360. if tuple(attention_mask.shape) != expected_shape:
  361. raise ValueError(
  362. f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
  363. )
  364. else:
  365. inverted_mask = 1.0 - attention_mask
  366. attention_mask = inverted_mask.masked_fill_(
  367. inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
  368. )
  369. else:
  370. attention_mask = attn_mask_converter.to_causal_4d(
  371. input_shape[0],
  372. input_shape[-1],
  373. key_value_length,
  374. dtype=inputs_embeds.dtype,
  375. )
  376. return attention_mask
  377. class MBartLearnedPositionalEmbedding(nn.Embedding):
  378. """
  379. This module learns positional embeddings up to a fixed maximum size.
  380. """
  381. def __init__(self, num_embeddings, embedding_dim):
  382. self.offset = 2
  383. super().__init__(num_embeddings + self.offset, embedding_dim)
  384. def forward(self, input_ids, past_key_values_length=0):
385. """`input_ids` shape is expected to be [bsz x seqlen]."""
  386. bsz, seq_len = input_ids.shape[:2]
  387. positions = torch.arange(
  388. past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
  389. ).expand([bsz, -1])
  390. return nn.Embedding.forward(self, positions + self.offset)
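A small sketch of how this positional embedding is queried (illustrative only, not part of the original file): the forward pass reads only the shape of `input_ids` and shifts every position index by `self.offset` before the table lookup.

# Illustrative sketch with toy sizes.
pos_emb = MBartLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=8)
ids = torch.zeros([2, 6], dtype=torch.int64)            # only the shape is used
full = pos_emb(ids)                                      # positions 0..5 -> [2, 6, 8]
step = pos_emb(ids[:, -1:], past_key_values_length=6)    # one decoding step -> [2, 1, 8]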
  391. class MBartPreTrainedModel(nn.Module):
  392. base_model_prefix = "model"
  393. supports_gradient_checkpointing = True
  394. _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
  395. _supports_flash_attn_2 = True
  396. def __init__(self, config):
  397. super().__init__()
  398. self.config = config
  399. def _initialize_weights(self, module):
  400. """
  401. Initialize the weights if they are not already initialized.
  402. """
  403. if getattr(module, "_is_hf_initialized", False):
  404. return
  405. self._init_weights(module)
  406. def post_init(self):
  407. self.apply(self._initialize_weights)
  408. def _init_weights(self, module):
  409. std = self.config.init_std
  410. if isinstance(module, nn.Linear):
  411. torch.nn.init.normal_(module.weight, mean=0.0, std=std)
  412. if module.bias is not None:
  413. torch.nn.init.constant_(module.bias, val=0.0)
  414. elif isinstance(module, nn.Embedding):
  415. torch.nn.init.normal_(module.weight, mean=0.0, std=std)
  416. if module.padding_idx is not None:
  417. torch.nn.init.constant_(module.weight[module.padding_idx], val=0.0)
  418. @property
  419. def dummy_inputs(self):
  420. pad_token = self.config.pad_token_id
  421. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
  422. dummy_inputs = {
  423. "attention_mask": input_ids.ne(pad_token),
  424. "input_ids": input_ids,
  425. }
  426. return dummy_inputs
  427. class MBartAttention(nn.Module):
  428. """Multi-headed attention from 'Attention Is All You Need' paper"""
  429. def __init__(
  430. self,
  431. embed_dim,
  432. num_heads,
  433. dropout: float = 0.0,
  434. is_decoder: bool = False,
  435. bias: bool = True,
  436. is_causal: bool = False,
  437. config=None,
  438. ):
  439. super().__init__()
  440. self.embed_dim = embed_dim
  441. self.num_heads = num_heads
  442. self.dropout = dropout
  443. self.head_dim = embed_dim // num_heads
  444. self.config = config
  445. if (self.head_dim * num_heads) != self.embed_dim:
  446. raise ValueError(
  447. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  448. f" and `num_heads`: {num_heads})."
  449. )
  450. self.scaling = self.head_dim ** -0.5
  451. self.is_decoder = is_decoder
  452. self.is_causal = is_causal
  453. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  454. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  455. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  456. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  457. def _shape(self, tensor, seq_len, bsz):
  458. return tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim]).permute(
  459. 0, 2, 1, 3
  460. )
  461. def forward(
  462. self,
  463. hidden_states,
  464. key_value_states=None,
  465. past_key_value=None,
  466. attention_mask=None,
  467. layer_head_mask=None,
  468. output_attentions=False,
  469. ):
  470. is_cross_attention = key_value_states is not None
  471. bsz, tgt_len, _ = hidden_states.shape
  472. query_states = self.q_proj(hidden_states) * self.scaling
  473. if (
  474. is_cross_attention
  475. and past_key_value is not None
  476. and past_key_value[0].shape[2] == key_value_states.shape[1]
  477. ):
  478. key_states = past_key_value[0]
  479. value_states = past_key_value[1]
  480. elif is_cross_attention:
  481. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  482. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  483. elif past_key_value is not None:
  484. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  485. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  486. key_states = torch.concat([past_key_value[0], key_states], dim=2)
  487. value_states = torch.concat([past_key_value[1], value_states], dim=2)
  488. else:
  489. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  490. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  491. if self.is_decoder:
  492. past_key_value = (key_states, value_states)
  493. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  494. query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
  495. key_states = key_states.reshape(proj_shape)
  496. value_states = value_states.reshape(proj_shape)
  497. src_len = key_states.shape[1]
  498. attn_weights = torch.bmm(query_states, key_states.permute([0, 2, 1]))
  499. if attention_mask is not None:
  500. attn_weights = (
  501. attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
  502. + attention_mask
  503. )
  504. attn_weights = attn_weights.reshape(
  505. [bsz * self.num_heads, tgt_len, src_len]
  506. )
  507. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  508. if layer_head_mask is not None:
  509. if tuple(layer_head_mask.shape) != (self.num_heads,):
  510. raise ValueError(
  511. f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
  512. f" {layer_head_mask.shape}"
  513. )
  514. attn_weights = layer_head_mask.reshape(
  515. [1, -1, 1, 1]
  516. ) * attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
  517. attn_weights = attn_weights.reshape(
  518. [bsz * self.num_heads, tgt_len, src_len]
  519. )
  520. if output_attentions:
  521. attn_weights_reshaped = attn_weights.reshape(
  522. [bsz, self.num_heads, tgt_len, src_len]
  523. )
  524. attn_weights = attn_weights_reshaped.reshape(
  525. [bsz * self.num_heads, tgt_len, src_len]
  526. )
  527. else:
  528. attn_weights_reshaped = None
  529. attn_probs = nn.functional.dropout(
  530. attn_weights, p=self.dropout, training=self.training
  531. )
  532. attn_output = torch.bmm(attn_probs, value_states)
  533. attn_output = attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
  534. attn_output = attn_output.permute([0, 2, 1, 3])
  535. attn_output = attn_output.reshape([bsz, tgt_len, self.embed_dim])
  536. attn_output = self.out_proj(attn_output)
  537. return attn_output, attn_weights_reshaped, past_key_value
  538. MBART_ATTENTION_CLASSES = {
  539. "eager": MBartAttention,
  540. }
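A minimal self-attention call against `MBartAttention`, shown as an illustrative sketch with toy shapes (not part of the original file); the cached `past_key_value` pair returned by the first call is reused for the next decoding step.

# Illustrative sketch: decoder self-attention with a key/value cache.
attn = MBartAttention(embed_dim=16, num_heads=4, is_decoder=True)
hidden = torch.randn([2, 5, 16])                         # [batch, tgt_len, embed_dim]
out, weights, past = attn(hidden, output_attentions=True)
# out: [2, 5, 16]; weights: [2, 4, 5, 5]; past: (key_states, value_states)
next_step = torch.randn([2, 1, 16])
out2, _, past2 = attn(next_step, past_key_value=past)    # appends to the cached states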
  541. class MBartDecoderLayer(nn.Module):
  542. def __init__(self, config):
  543. super().__init__()
  544. self.embed_dim = config.d_model
  545. self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
  546. embed_dim=self.embed_dim,
  547. num_heads=config.decoder_attention_heads,
  548. dropout=config.attention_dropout,
  549. is_decoder=True,
  550. is_causal=True,
  551. config=config,
  552. )
  553. self.is_export = config.is_export
  554. self.dropout = config.dropout
  555. self.activation_fn = F.gelu
  556. self.activation_dropout = config.activation_dropout
  557. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  558. self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
  559. self.embed_dim,
  560. config.decoder_attention_heads,
  561. dropout=config.attention_dropout,
  562. is_decoder=True,
  563. config=config,
  564. )
  565. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  566. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  567. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  568. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  569. def forward(
  570. self,
  571. hidden_states,
  572. attention_mask=None,
  573. encoder_hidden_states=None,
  574. encoder_attention_mask=None,
  575. layer_head_mask=None,
  576. cross_attn_layer_head_mask=None,
  577. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  578. output_attentions: Optional[bool] = False,
  579. use_cache: Optional[bool] = True,
  580. ) -> torch.Tensor:
  581. residual = hidden_states
  582. hidden_states = self.self_attn_layer_norm(hidden_states)
  583. self_attn_past_key_value = (
  584. past_key_value[:2] if past_key_value is not None else None
  585. )
  586. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  587. hidden_states=hidden_states,
  588. past_key_value=self_attn_past_key_value,
  589. attention_mask=attention_mask,
  590. layer_head_mask=layer_head_mask,
  591. output_attentions=output_attentions,
  592. )
  593. hidden_states = nn.functional.dropout(
  594. hidden_states, p=self.dropout, training=self.training
  595. )
  596. hidden_states = residual + hidden_states
  597. cross_attn_present_key_value = None
  598. cross_attn_weights = None
  599. if encoder_hidden_states is not None:
  600. residual = hidden_states
  601. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  602. cross_attn_past_key_value = (
  603. past_key_value[-2:] if past_key_value is not None else None
  604. )
  605. hidden_states, cross_attn_weights, cross_attn_present_key_value = (
  606. self.encoder_attn(
  607. hidden_states=hidden_states,
  608. key_value_states=encoder_hidden_states,
  609. attention_mask=encoder_attention_mask,
  610. layer_head_mask=cross_attn_layer_head_mask,
  611. past_key_value=cross_attn_past_key_value,
  612. output_attentions=output_attentions,
  613. )
  614. )
  615. hidden_states = nn.functional.dropout(
  616. hidden_states, p=self.dropout, training=self.training
  617. )
  618. hidden_states = residual + hidden_states
  619. present_key_value = present_key_value + cross_attn_present_key_value
  620. residual = hidden_states
  621. hidden_states = self.final_layer_norm(hidden_states)
  622. hidden_states = self.activation_fn(self.fc1(hidden_states))
  623. hidden_states = nn.functional.dropout(
  624. hidden_states, p=self.activation_dropout, training=self.training
  625. )
  626. hidden_states = self.fc2(hidden_states)
  627. hidden_states = nn.functional.dropout(
  628. hidden_states, p=self.dropout, training=self.training
  629. )
  630. hidden_states = residual + hidden_states
  631. outputs = (hidden_states,)
  632. if output_attentions:
  633. outputs += (self_attn_weights, cross_attn_weights)
  634. if self.is_export:
  635. outputs += (present_key_value,)
  636. else:
  637. if use_cache:
  638. outputs += (present_key_value,)
  639. return outputs
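Putting the pieces together, one decoder layer can be exercised with a toy `MBartConfig`; the sketch below is illustrative only and uses made-up sizes.

# Illustrative sketch: run a single decoder layer with cross-attention.
cfg = MBartConfig(d_model=32, decoder_attention_heads=4, decoder_ffn_dim=64)
layer = MBartDecoderLayer(cfg)
x = torch.randn([2, 5, 32])                              # decoder hidden states
mem = torch.randn([2, 7, 32])                            # encoder output
outputs = layer(x, encoder_hidden_states=mem)
hidden, present_key_value = outputs[0], outputs[-1]      # use_cache=True by default
# hidden: [2, 5, 32]; present_key_value holds self- and cross-attention key/value caches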
  640. class MBartForCausalLM(MBartPreTrainedModel):
  641. _tied_weights_keys = ["lm_head.weight"]
  642. def __init__(self, config):
  643. config = copy.deepcopy(config)
  644. config.is_decoder = True
  645. config.is_encoder_decoder = False
  646. super().__init__(config)
  647. self.model = MBartDecoderWrapper(config)
  648. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  649. self.post_init()
  650. def get_input_embeddings(self):
  651. return self.model.decoder.embed_tokens
  652. def set_input_embeddings(self, value):
  653. self.model.decoder.embed_tokens = value
  654. def get_output_embeddings(self):
  655. return self.lm_head
  656. def set_output_embeddings(self, new_embeddings):
  657. self.lm_head = new_embeddings
  658. def set_decoder(self, decoder):
  659. self.model.decoder = decoder
  660. def get_decoder(self):
  661. return self.model.decoder
  662. def forward(
  663. self,
  664. input_ids=None,
  665. attention_mask=None,
  666. encoder_hidden_states=None,
  667. encoder_attention_mask=None,
  668. head_mask=None,
  669. cross_attn_head_mask=None,
  670. past_key_values=None,
  671. inputs_embeds=None,
  672. labels=None,
  673. use_cache=None,
  674. output_attentions=None,
  675. output_hidden_states=None,
  676. return_dict=None,
  677. ):
  678. output_attentions = (
  679. output_attentions
  680. if output_attentions is not None
  681. else self.config.output_attentions
  682. )
  683. output_hidden_states = (
  684. output_hidden_states
  685. if output_hidden_states is not None
  686. else self.config.output_hidden_states
  687. )
  688. return_dict = (
  689. return_dict if return_dict is not None else self.config.use_return_dict
  690. )
  691. outputs = self.model.decoder(
  692. input_ids=input_ids,
  693. attention_mask=attention_mask,
  694. encoder_hidden_states=encoder_hidden_states,
  695. encoder_attention_mask=encoder_attention_mask,
  696. head_mask=head_mask,
  697. cross_attn_head_mask=cross_attn_head_mask,
  698. past_key_values=past_key_values,
  699. inputs_embeds=inputs_embeds,
  700. use_cache=use_cache,
  701. output_attentions=output_attentions,
  702. output_hidden_states=output_hidden_states,
  703. return_dict=return_dict,
  704. )
  705. logits = self.lm_head(outputs[0])
  706. loss = None
  707. if labels is not None:
708. labels = labels.to(logits.device)
  709. loss_fct = CrossEntropyLoss()
  710. loss = loss_fct(
  711. logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
  712. )
  713. if not return_dict:
  714. output = (logits,) + outputs[1:]
  715. return (loss,) + output if loss is not None else output
  716. return CausalLMOutputWithCrossAttentions(
  717. loss=loss,
  718. logits=logits,
  719. past_key_values=outputs.past_key_values,
  720. hidden_states=outputs.hidden_states,
  721. attentions=outputs.attentions,
  722. cross_attentions=outputs.cross_attentions,
  723. )
  724. def prepare_inputs_for_generation(
  725. self,
  726. input_ids,
  727. past_key_values=None,
  728. attention_mask=None,
  729. use_cache=None,
  730. **kwargs,
  731. ):
  732. if attention_mask is None:
  733. attention_mask = input_ids.new_ones(input_ids.shape)
  734. if past_key_values:
  735. past_length = past_key_values[0][0].shape[2]
  736. if input_ids.shape[1] > past_length:
  737. remove_prefix_length = past_length
  738. else:
  739. remove_prefix_length = input_ids.shape[1] - 1
  740. input_ids = input_ids[:, remove_prefix_length:]
  741. return {
  742. "input_ids": input_ids,
  743. "attention_mask": attention_mask,
  744. "past_key_values": past_key_values,
  745. "use_cache": use_cache,
  746. }
  747. @staticmethod
  748. def _reorder_cache(past_key_values, beam_idx):
  749. reordered_past = ()
  750. for layer_past in past_key_values:
  751. reordered_past += (
  752. tuple(
  753. past_state.index_select(0, beam_idx) for past_state in layer_past
  754. ),
  755. )
  756. return reordered_past
  757. class myLayerNorm(nn.LayerNorm):
  758. """
  759. Custom implementation of Layer Normalization, with additional options.
  760. This class extends the standard LayerNorm to include optional features,
  761. such as drop block regularization, which might be used for improving
  762. model generalization.
  763. Args:
  764. num_channels (int): The number of features or channels in the input.
  765. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
  766. affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
767. drop_block (optional): Accepted for interface compatibility; not used by this implementation. Default is None.
  768. """
  769. def __init__(
  770. self,
  771. num_channels,
  772. eps=1e-5,
  773. affine=True,
  774. drop_block=None,
  775. ):
  776. super(nn.LayerNorm, self).__init__()
  777. self._epsilon = eps
  778. self.num_channels = num_channels
  779. if affine:
  780. self.weight = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
  781. self.bias = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
  782. torch.nn.init.ones_(self.weight)
  783. torch.nn.init.zeros_(self.bias)
  784. def forward(self, x):
  785. x = F.layer_norm(
  786. x,
  787. [self.num_channels],
  788. weight=self.weight,
  789. bias=self.bias,
  790. eps=self._epsilon,
  791. )
  792. return x
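In use, the module behaves like a standard layer normalization over the last dimension; a short illustrative sketch (not part of the original file):

# Illustrative sketch: normalize the channel dimension of a [batch, seq, channels] tensor.
ln = myLayerNorm(num_channels=32)
x = torch.randn([2, 5, 32])
y = ln(x)                                                # same shape, normalized per position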
  793. class MBartDecoder(MBartPreTrainedModel):
  794. """
  795. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
  796. Args:
  797. config
  798. embed_tokens (nn.Embedding): output embedding
  799. """
  800. def __init__(self, config, embed_tokens=None):
  801. super().__init__(config)
  802. self.dropout = config.dropout
  803. self.layerdrop = config.decoder_layerdrop
  804. self.padding_idx = config.pad_token_id
  805. self.max_target_positions = config.max_position_embeddings
  806. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  807. self.embed_tokens = nn.Embedding(
  808. config.vocab_size, config.d_model, self.padding_idx
  809. )
  810. if embed_tokens is not None:
  811. self.embed_tokens.weight = embed_tokens.weight
  812. self.embed_positions = MBartLearnedPositionalEmbedding(
  813. config.max_position_embeddings,
  814. config.d_model,
  815. )
  816. self.layers = nn.ModuleList(
  817. [MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
  818. )
  819. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  820. self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
  821. self.layer_norm = nn.LayerNorm(config.d_model)
  822. self.gradient_checkpointing = False
  823. # Initialize weights and apply final processing
  824. self.post_init()
  825. self.is_export = config.is_export
  826. def get_input_embeddings(self):
  827. return self.embed_tokens
  828. def set_input_embeddings(self, value):
  829. self.embed_tokens = value
  830. def forward(
  831. self,
  832. input_ids=None,
  833. attention_mask=None,
  834. encoder_hidden_states=None,
  835. encoder_attention_mask=None,
  836. head_mask=None,
  837. cross_attn_head_mask=None,
  838. past_key_values=None,
  839. inputs_embeds=None,
  840. use_cache=None,
  841. output_attentions=None,
  842. output_hidden_states=None,
  843. return_dict=None,
  844. ):
  845. output_attentions = (
  846. output_attentions
  847. if output_attentions is not None
  848. else self.config.output_attentions
  849. )
  850. output_hidden_states = (
  851. output_hidden_states
  852. if output_hidden_states is not None
  853. else self.config.output_hidden_states
  854. )
  855. use_cache = use_cache if use_cache is not None else self.config.use_cache
  856. return_dict = (
  857. return_dict if return_dict is not None else self.config.use_return_dict
  858. )
  859. if input_ids is not None and inputs_embeds is not None:
  860. raise ValueError(
  861. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  862. )
  863. elif input_ids is not None:
  864. input = input_ids
  865. input_shape = input.shape
  866. input_ids = input_ids.reshape([-1, input_shape[-1]])
  867. elif inputs_embeds is not None:
  868. input_shape = inputs_embeds.shape[:-1]
  869. input = inputs_embeds[:, :, -1]
  870. else:
  871. raise ValueError(
  872. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  873. )
  874. past_key_values_length = (
  875. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  876. )
  877. if inputs_embeds is None:
  878. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  879. if self._use_flash_attention_2:
  880. attention_mask = (
  881. attention_mask
  882. if (attention_mask is not None and 0 in attention_mask)
  883. else None
  884. )
  885. else:
  886. attention_mask = _prepare_4d_causal_attention_mask(
  887. attention_mask,
  888. input_shape,
  889. inputs_embeds,
  890. past_key_values_length,
  891. is_export=self.is_export,
  892. )
  893. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  894. if self._use_flash_attention_2:
  895. encoder_attention_mask = (
  896. encoder_attention_mask if 0 in encoder_attention_mask else None
  897. )
  898. else:
  899. encoder_attention_mask = _prepare_4d_attention_mask(
  900. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  901. )
  902. # embed positions
  903. positions = self.embed_positions(input, past_key_values_length)
  904. hidden_states = inputs_embeds + positions
  905. hidden_states = self.layernorm_embedding(hidden_states)
  906. hidden_states = nn.functional.dropout(
  907. hidden_states, p=self.dropout, training=self.training
  908. )
  909. if self.gradient_checkpointing and self.training:
  910. if use_cache:
911. warnings.warn(
912. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
913. )
  914. use_cache = False
  915. all_hidden_states = () if output_hidden_states else None
  916. all_self_attns = () if output_attentions else None
  917. all_cross_attentions = (
  918. () if (output_attentions and encoder_hidden_states is not None) else None
  919. )
  920. next_decoder_cache = () if use_cache else None
  921. for attn_mask, mask_name in zip(
  922. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  923. ):
  924. if attn_mask is not None:
  925. if attn_mask.shape[0] != len(self.layers):
  926. raise ValueError(
  927. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  928. f" {attn_mask.shape[0]}."
  929. )
  930. for idx, decoder_layer in enumerate(self.layers):
  931. if output_hidden_states:
  932. all_hidden_states += (hidden_states,)
  933. if self.training:
  934. dropout_probability = torch.rand([])
  935. if dropout_probability < self.layerdrop:
  936. continue
  937. past_key_value = (
  938. past_key_values[idx] if past_key_values is not None else None
  939. )
  940. if self.gradient_checkpointing and self.training:
  941. layer_outputs = self._gradient_checkpointing_func(
  942. decoder_layer.__call__,
  943. hidden_states,
  944. attention_mask,
  945. encoder_hidden_states,
  946. encoder_attention_mask,
  947. head_mask[idx] if head_mask is not None else None,
  948. (
  949. cross_attn_head_mask[idx]
  950. if cross_attn_head_mask is not None
  951. else None
  952. ),
  953. None,
  954. output_attentions,
  955. use_cache,
  956. )
  957. else:
  958. layer_outputs = decoder_layer(
  959. hidden_states,
  960. attention_mask=attention_mask,
  961. encoder_hidden_states=encoder_hidden_states,
  962. encoder_attention_mask=encoder_attention_mask,
  963. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  964. cross_attn_layer_head_mask=(
  965. cross_attn_head_mask[idx]
  966. if cross_attn_head_mask is not None
  967. else None
  968. ),
  969. past_key_value=past_key_value,
  970. output_attentions=output_attentions,
  971. use_cache=use_cache,
  972. )
  973. hidden_states = layer_outputs[0]
  974. if use_cache:
  975. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  976. if output_attentions:
  977. all_self_attns += (layer_outputs[1],)
  978. if encoder_hidden_states is not None:
  979. all_cross_attentions += (layer_outputs[2],)
  980. hidden_states = self.layer_norm(hidden_states)
  981. if output_hidden_states:
  982. all_hidden_states += (hidden_states,)
  983. next_cache = next_decoder_cache if use_cache else None
  984. if not return_dict:
  985. return tuple(
  986. v
  987. for v in [
  988. hidden_states,
  989. next_cache,
  990. all_hidden_states,
  991. all_self_attns,
  992. all_cross_attentions,
  993. ]
  994. if v is not None
  995. )
  996. return BaseModelOutputWithPastAndCrossAttentions(
  997. last_hidden_state=hidden_states,
  998. past_key_values=next_cache,
  999. hidden_states=all_hidden_states,
  1000. attentions=all_self_attns,
  1001. cross_attentions=all_cross_attentions,
  1002. )
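An end-to-end sketch of the decoder stack with toy sizes (illustrative only, not part of the original file). `output_attentions` is passed explicitly because the `MBartConfig` above does not define that attribute, and a 2-D `attention_mask` is supplied since the mask-preparation path expects one.

# Illustrative sketch: run the full decoder on a batch of token ids.
cfg = MBartConfig(vocab_size=100, d_model=32, decoder_layers=2,
                  decoder_attention_heads=4, decoder_ffn_dim=64,
                  max_position_embeddings=64)
decoder = MBartDecoder(cfg)
input_ids = torch.randint(0, 100, [2, 7])
out = decoder(input_ids=input_ids, attention_mask=torch.ones([2, 7]), output_attentions=False)
# out.last_hidden_state: [2, 7, 32]
# out.past_key_values: one tuple of cached self-attention key/value states per layer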
  1003. class MBartDecoderWrapper(MBartPreTrainedModel):
  1004. """
  1005. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1006. used in combination with the [`EncoderDecoderModel`] framework.
  1007. """
  1008. def __init__(self, config):
  1009. super().__init__(config)
  1010. self.decoder = MBartDecoder(config)
  1011. def forward(self, *args, **kwargs):
  1012. return self.decoder(*args, **kwargs)
  1013. def _in_projection(
  1014. q: torch.Tensor,
  1015. k: torch.Tensor,
  1016. v: torch.Tensor,
  1017. w_q: torch.Tensor,
  1018. w_k: torch.Tensor,
  1019. w_v: torch.Tensor,
  1020. b_q: Optional[torch.Tensor] = None,
  1021. b_k: Optional[torch.Tensor] = None,
  1022. b_v: Optional[torch.Tensor] = None,
  1023. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  1024. Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
  1025. assert w_q.shape == (
  1026. Eq,
  1027. Eq,
  1028. ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
  1029. assert w_k.shape == (
  1030. Eq,
  1031. Ek,
  1032. ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
  1033. assert w_v.shape == (
  1034. Eq,
  1035. Ev,
  1036. ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
  1037. assert b_q is None or b_q.shape == (
  1038. Eq,
  1039. ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
  1040. assert b_k is None or b_k.shape == (
  1041. Eq,
  1042. ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
  1043. assert b_v is None or b_v.shape == (
  1044. Eq,
  1045. ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
  1046. return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
  1047. def _scaled_dot_product_attention(
  1048. q: torch.Tensor,
  1049. k: torch.Tensor,
  1050. v: torch.Tensor,
  1051. attn_mask: Optional[torch.Tensor] = None,
  1052. dropout_p: float = 0.0,
  1053. ) -> Tuple[torch.Tensor, torch.Tensor]:
  1054. B, Nt, E = q.shape
  1055. q = q / math.sqrt(E)
  1056. attn = torch.bmm(q, k.permute([0, 2, 1]))
  1057. if attn_mask is not None:
  1058. attn += attn_mask
  1059. attn = F.softmax(attn, dim=-1)
  1060. if dropout_p > 0.0:
  1061. attn = F.dropout(attn, p=dropout_p)
  1062. output = torch.bmm(attn, v)
  1063. return output, attn
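The helper above expects inputs that are already projected and flattened to `[batch * heads, length, head_dim]`; a tiny illustrative sketch (not part of the original file):

# Illustrative sketch: 8 = batch * heads, 5 = sequence length, 4 = head_dim.
q = torch.randn([8, 5, 4])
k = torch.randn([8, 5, 4])
v = torch.randn([8, 5, 4])
ctx, attn = _scaled_dot_product_attention(q, k, v)
# ctx: [8, 5, 4]; attn: [8, 5, 5], each row sums to 1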
1064. def linear(x, w, b, is_transpose=False):
  1065. if is_transpose:
  1066. w = w.T
  1067. if b is not None:
  1068. return torch.matmul(x, w) + b
  1069. else:
  1070. return torch.matmul(x, w)
  1071. def _in_projection_packed(
  1072. q: Tensor,
  1073. k: Tensor,
  1074. v: Tensor,
  1075. w: Tensor,
  1076. b: Optional[Tensor] = None,
  1077. is_export=False,
  1078. ) -> List[Tensor]:
  1079. E = q.shape[-1]
  1080. if k is v:
  1081. if q is k:
  1082. proj = linear(q, w, b, is_transpose=True)
  1083. if is_export:
  1084. B, D, L = proj.shape
  1085. proj = proj.reshape([B, D, 3, E])
  1086. proj = (
  1087. proj.unsqueeze(0)
  1088. .permute([3, 1, 2, 0, 4])
  1089. .squeeze(-2)
  1090. .contiguous()
  1091. )
  1092. else:
  1093. proj = (
  1094. proj.unflatten(-1, (3, E))
  1095. .unsqueeze(0)
  1096. .permute([3, 1, 2, 0, 4])
  1097. .squeeze(-2)
  1098. .contiguous()
  1099. )
  1100. return proj[0], proj[1], proj[2]
  1101. else:
  1102. w_q, w_k, w_v = w.chunk(3)
  1103. if b is None:
  1104. b_q = b_k = b_v = None
  1105. else:
  1106. b_q, b_k, b_v = b.chunk(3)
1107. return linear(q, w_q, b_q, is_transpose=True), linear(k, w_k, b_k, is_transpose=True), linear(v, w_v, b_v, is_transpose=True)
  1108. def multi_head_attention_forward(
  1109. query: torch.Tensor,
  1110. key: torch.Tensor,
  1111. value: torch.Tensor,
  1112. embed_dim_to_check: int,
  1113. num_heads: int,
  1114. in_proj_weight: torch.Tensor,
  1115. in_proj_bias: Optional[torch.Tensor],
  1116. bias_k: Optional[torch.Tensor],
  1117. bias_v: Optional[torch.Tensor],
  1118. add_zero_attn: bool,
  1119. dropout_p: float,
  1120. out_proj_weight: torch.Tensor,
  1121. out_proj_bias: Optional[torch.Tensor],
  1122. training: bool = True,
  1123. key_padding_mask: Optional[torch.Tensor] = None,
  1124. need_weights: bool = True,
  1125. attn_mask: Optional[torch.Tensor] = None,
  1126. use_separate_proj_weight: bool = False,
  1127. q_proj_weight: Optional[torch.Tensor] = None,
  1128. k_proj_weight: Optional[torch.Tensor] = None,
  1129. v_proj_weight: Optional[torch.Tensor] = None,
  1130. static_k: Optional[torch.Tensor] = None,
  1131. static_v: Optional[torch.Tensor] = None,
  1132. is_export=False,
  1133. ):
  1134. tgt_len, bsz, embed_dim = query.shape
  1135. src_len, _, _ = key.shape
  1136. if isinstance(embed_dim, torch.Tensor):
  1137. head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
  1138. else:
  1139. head_dim = embed_dim // num_heads
  1140. q, k, v = _in_projection_packed(
  1141. query, key, value, in_proj_weight, in_proj_bias, is_export
  1142. )
  1143. if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
  1144. warnings.warn(
  1145. "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
  1146. )
  1147. key_padding_mask = key_padding_mask.to(torch.bool)
  1148. if bias_k is not None and bias_v is not None: # False
  1149. assert static_k is None, "bias cannot be added to static key."
  1150. assert static_v is None, "bias cannot be added to static value."
  1151. k = torch.concat([k, bias_k.repeat(1, bsz, 1)])
  1152. v = torch.concat([v, bias_v.repeat(1, bsz, 1)])
  1153. else:
  1154. assert bias_k is None
  1155. assert bias_v is None
  1156. q = q.reshape([tgt_len, bsz * num_heads, head_dim]).permute([1, 0, 2])
  1157. if static_k is None: # True
  1158. k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
  1159. else:
  1160. assert (
  1161. static_k.shape[0] == bsz * num_heads
  1162. ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
  1163. assert (
  1164. static_k.shape[2] == head_dim
  1165. ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
  1166. k = static_k
  1167. if static_v is None: # True
1168. v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
  1169. else:
  1170. assert (
  1171. static_v.shape[0] == bsz * num_heads
  1172. ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
  1173. assert (
  1174. static_v.shape[2] == head_dim
  1175. ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
  1176. v = static_v
  1177. src_len = k.shape[1]
  1178. if not training:
  1179. dropout_p = 0.0
  1180. attn_output, attn_output_weights = _scaled_dot_product_attention(
  1181. q, k, v, attn_mask, dropout_p
  1182. )
  1183. attn_output = attn_output.permute([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
  1184. attn_output = linear(
  1185. attn_output, out_proj_weight, out_proj_bias, is_transpose=False
  1186. )
  1187. if need_weights:
  1188. attn_output_weights = attn_output_weights.reshape(
  1189. [bsz, num_heads, tgt_len, src_len]
  1190. )
  1191. return attn_output, attn_output_weights.sum(dim=1) / num_heads
  1192. else:
  1193. return attn_output, None
  1194. class MyMultiheadAttention(nn.Module):
  1195. """
  1196. Custom implementation of a multi-head attention layer.
  1197. Attributes:
  1198. __constants__ (list): List of constant attributes.
1199. bias_k (Optional[torch.Tensor]): Optional tensor for key bias.
1200. bias_v (Optional[torch.Tensor]): Optional tensor for value bias.
  1201. Args:
  1202. embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
  1203. num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
  1204. dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
  1205. bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
  1206. add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
  1207. add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
  1208. kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
  1209. vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
  1210. batch_first (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default is False.
  1211. device (optional): The device on which the layer's parameters should be initialized. Default is None.
  1212. dtype (optional): The data type for the parameters. Default is None.
  1213. is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
  1214. """
  1215. __constants__ = ["batch_first"]
  1216. bias_k: Optional[torch.Tensor]
  1217. bias_v: Optional[torch.Tensor]
  1218. def __init__(
  1219. self,
  1220. embed_dim,
  1221. num_heads,
  1222. dropout=0.0,
  1223. bias=True,
  1224. add_bias_kv=False,
  1225. add_zero_attn=False,
  1226. kdim=None,
  1227. vdim=None,
  1228. batch_first=False,
  1229. device=None,
  1230. dtype=None,
  1231. is_export=False,
  1232. ) -> None:
  1233. super(MyMultiheadAttention, self).__init__()
  1234. self.embed_dim = embed_dim
  1235. self.kdim = kdim if kdim is not None else embed_dim
  1236. self.vdim = vdim if vdim is not None else embed_dim
  1237. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  1238. self.num_heads = num_heads
  1239. self.dropout = dropout
  1240. self.batch_first = batch_first
  1241. self.head_dim = embed_dim // num_heads
  1242. self.is_export = is_export
  1243. assert (
  1244. self.head_dim * num_heads == self.embed_dim
  1245. ), "embed_dim must be divisible by num_heads"
1246. if not self._qkv_same_embed_dim:
1247. raise NotImplementedError("separate q/k/v projection weights (kdim/vdim != embed_dim) are not supported here")
  1248. else:
  1249. if dtype is None:
  1250. dtype = torch.float32
  1251. self.in_proj_weight = torch.nn.Parameter(torch.randn(3 * embed_dim, embed_dim) * 0.01)
  1252. self.q_proj_weight = None
  1253. self.k_proj_weight = None
  1254. self.v_proj_weight = None
  1255. if bias:
1256. self.in_proj_bias = torch.nn.Parameter(torch.zeros(3 * embed_dim))
  1258. else:
  1259. self.in_proj_bias = None
  1260. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
1261. if add_bias_kv:
1262. raise NotImplementedError("add_bias_kv=True is not supported here")
  1263. else:
  1264. self.bias_k = self.bias_v = None
  1265. self.add_zero_attn = add_zero_attn
  1266. self._reset_parameters()
  1267. def _reset_parameters(self):
  1268. if self._qkv_same_embed_dim:
  1269. torch.nn.init.xavier_normal_(self.in_proj_weight)
  1270. else:
  1271. torch.nn.init.xavier_normal_(self.q_proj_weight)
  1272. torch.nn.init.xavier_normal_(self.k_proj_weight)
  1273. torch.nn.init.xavier_normal_(self.v_proj_weight)
  1274. if self.in_proj_bias is not None:
  1275. torch.nn.init.zeros_(self.in_proj_bias)
  1276. torch.nn.init.zeros_(self.out_proj.bias)
  1277. if self.bias_k is not None:
  1278. torch.nn.init.xavier_normal_(self.bias_k)
  1279. if self.bias_v is not None:
  1280. torch.nn.init.xavier_normal_(self.bias_v)
  1281. def forward(
  1282. self,
  1283. query: torch.Tensor,
  1284. key: torch.Tensor,
  1285. value: torch.Tensor,
  1286. key_padding_mask: Optional[torch.Tensor] = None,
  1287. need_weights: bool = True,
  1288. attn_mask: Optional[torch.Tensor] = None,
  1289. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  1290. attn_output, attn_output_weights = multi_head_attention_forward(
  1291. query,
  1292. key,
  1293. value,
  1294. self.embed_dim,
  1295. self.num_heads,
  1296. self.in_proj_weight,
  1297. self.in_proj_bias,
  1298. self.bias_k,
  1299. self.bias_v,
  1300. self.add_zero_attn,
  1301. self.dropout,
  1302. self.out_proj.weight,
  1303. self.out_proj.bias,
  1304. training=self.training,
  1305. key_padding_mask=key_padding_mask,
  1306. need_weights=need_weights,
  1307. attn_mask=attn_mask,
  1308. is_export=self.is_export,
  1309. )
  1310. return attn_output, attn_output_weights
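# Illustrative usage sketch (not part of the original module). It assumes the seq-first
# (L, B, E) layout that multi_head_attention_forward above expects, since batch_first is
# not handled by this port; the sizes below are arbitrary example values.
def _demo_my_multihead_attention():
    embed_dim, num_heads, seq_len, batch_size = 64, 8, 10, 2
    attn = MyMultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
    x = torch.randn(seq_len, batch_size, embed_dim)
    out, weights = attn(x, x, x)  # self-attention: query == key == value
    assert out.shape == (seq_len, batch_size, embed_dim)
    assert weights.shape == (batch_size, seq_len, seq_len)  # weights averaged over heads
    return out, weights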
  1311. class LogitsProcessorList(list):
  1312. """
  1313. A list of logits processors that can be applied sequentially.
  1314. Methods:
  1315. __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
  1316. """
  1317. def __call__(self, input_ids, scores, **kwargs):
  1318. for processor in self:
  1319. function_args = inspect.signature(processor.__call__).parameters
  1320. if len(function_args) > 2:
  1321. if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
  1322. raise ValueError(
  1323. f"Make sure that all the required parameters: {list(function_args.keys())} for "
  1324. f"{processor.__class__} are passed to the logits processor."
  1325. )
  1326. scores = processor(input_ids, scores, **kwargs)
  1327. else:
  1328. scores = processor(input_ids, scores)
  1329. return scores
  1330. class ForcedEOSTokenLogitsProcessor(object):
  1331. """
  1332. A processor that forces the generation of an end-of-sequence (EOS) token
  1333. at a specified position in the sequence.
  1334. This is typically used in language generation tasks to ensure that the
  1335. generated sequence ends properly when it reaches a certain length.
  1336. Args:
  1337. max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
  1338. eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
  1339. """
  1340. def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
  1341. self.max_length = max_length
  1342. if isinstance(eos_token_id, int):
  1343. eos_token_id = [eos_token_id]
  1344. self.eos_token_id = eos_token_id
  1345. def __call__(self, input_ids, scores):
  1346. cur_len = input_ids.shape[-1]
  1347. scores_processed = scores
  1348. if cur_len == self.max_length - 1:
  1349. scores_processed = torch.full_like(scores, -math.inf)
  1350. scores_processed[:, self.eos_token_id] = 0
1351. return scores_processed
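# Illustrative sketch (not from the original file): once the running sequence reaches
# max_length - 1, the processor masks every logit except EOS, so greedy decoding has to
# emit EOS. The vocabulary size, max_length and token ids are arbitrary example values.
def _demo_forced_eos_processor():
    processors = LogitsProcessorList([ForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=2)])
    input_ids = torch.zeros((1, 4), dtype=torch.int64)  # cur_len == max_length - 1
    scores = torch.randn(1, 10)                         # fake next-token logits
    scores = processors(input_ids, scores)
    assert int(torch.argmax(scores, dim=-1)) == 2       # only the EOS column survives
    return scores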
  1352. @dataclass
1353. class CausalLMOutputWithCrossAttentions(ModelOutput):
"""Base class for causal language model (or autoregressive) outputs with cross-attentions."""
1354. loss = None
  1355. logits = None
  1356. past_key_values = None
  1357. hidden_states = None
  1358. attentions = None
  1359. cross_attentions = None
  1360. def __init__(self, *args, **kwargs):
  1361. super().__init__(*args, **kwargs)
  1362. @dataclass
  1363. class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
  1364. """
1365. Base class for causal language model (or autoregressive) outputs, extended with the symbol-count prediction used by the length-aware decoder.
  1366. """
  1367. logits = None
  1368. counting = None
  1369. past_key_values = None
  1370. hidden_states = None
  1371. attentions = None
  1372. cross_attentions = None
  1373. def __init__(self, *args, **kwargs):
  1374. super().__init__(*args, **kwargs)
  1375. class CustomMBartDecoder(MBartDecoder):
  1376. """
  1377. A custom MBartDecoder that includes additional processing layers.
  1378. This class extends the MBartDecoder by adding a customizable neural network
  1379. component called `counting_context_weight`, which applies a series of linear
  1380. transformations followed by ReLU activations. This can be used to modify or
  1381. enhance the decoder's behavior for specific tasks.
  1382. Args:
  1383. config: The configuration object containing model parameters.
  1384. """
  1385. def __init__(self, config):
  1386. super().__init__(config)
  1387. hidden_size = config.d_model
  1388. self.is_export = config.is_export
  1389. self.counting_context_weight = nn.Sequential(
  1390. nn.Linear(config.vocab_size, hidden_size),
  1391. nn.ReLU(),
  1392. nn.Linear(hidden_size, hidden_size),
  1393. nn.ReLU(),
  1394. nn.Linear(hidden_size, config.d_model),
  1395. )
  1396. def forward(
  1397. self,
  1398. input_ids=None,
  1399. attention_mask=None,
  1400. count_pred=None,
  1401. encoder_hidden_states=None,
  1402. encoder_attention_mask=None,
  1403. head_mask=None,
  1404. cross_attn_head_mask=None,
  1405. past_key_values=None,
  1406. inputs_embeds=None,
  1407. use_cache=None,
  1408. output_attentions=None,
  1409. output_hidden_states=None,
  1410. return_dict=None,
  1411. ):
1412. self.is_export = not self.training
  1413. output_attentions = (
  1414. output_attentions
  1415. if output_attentions is not None
  1416. else self.config.output_attentions
  1417. )
  1418. output_hidden_states = (
  1419. output_hidden_states
  1420. if output_hidden_states is not None
  1421. else self.config.output_hidden_states
  1422. )
  1423. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1424. return_dict = (
  1425. return_dict if return_dict is not None else self.config.use_return_dict
  1426. )
  1427. if input_ids is not None and inputs_embeds is not None:
  1428. raise ValueError(
  1429. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  1430. )
  1431. elif input_ids is not None:
  1432. input = input_ids
  1433. input_shape = input.shape
  1434. input_ids = input_ids.reshape([-1, input_shape[-1]])
  1435. elif inputs_embeds is not None:
  1436. input_shape = inputs_embeds.shape[:-1]
  1437. input = inputs_embeds[:, :, -1]
  1438. else:
  1439. raise ValueError(
  1440. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1441. )
  1442. past_key_values_length = (
  1443. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  1444. )
  1445. if inputs_embeds is None:
  1446. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  1447. if self._use_flash_attention_2:
  1448. attention_mask = (
  1449. attention_mask
  1450. if (attention_mask is not None and 0 in attention_mask)
  1451. else None
  1452. )
  1453. else:
  1454. if self.is_export:
  1455. attention_mask = _prepare_4d_causal_attention_mask_export(
  1456. attention_mask,
  1457. input_shape,
  1458. inputs_embeds,
  1459. past_key_values_length,
  1460. is_export=self.is_export,
  1461. ).to(torch.float32)
  1462. else:
  1463. attention_mask = _prepare_4d_causal_attention_mask(
  1464. attention_mask,
  1465. input_shape,
  1466. inputs_embeds,
  1467. past_key_values_length,
  1468. is_export=self.is_export,
  1469. )
  1470. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1471. if self._use_flash_attention_2:
  1472. encoder_attention_mask = (
  1473. encoder_attention_mask if 0 in encoder_attention_mask else None
  1474. )
  1475. else:
  1476. encoder_attention_mask = _prepare_4d_attention_mask(
  1477. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  1478. )
  1479. # embed positions
  1480. positions = self.embed_positions(input, past_key_values_length)
  1481. hidden_states = inputs_embeds + positions
1482. # Length-aware decoding: add the counting context (the symbol-count prediction projected back to d_model) to every decoder position.
  1483. if count_pred is not None:
  1484. count_context_weight = self.counting_context_weight(count_pred)
  1485. hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
  1486. hidden_states = self.layernorm_embedding(hidden_states)
  1487. hidden_states = nn.functional.dropout(
  1488. hidden_states, p=self.dropout, training=self.training
  1489. )
  1490. if self.gradient_checkpointing and self.training:
  1491. if use_cache:
  1492. print(
1493. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1494. )
  1495. use_cache = False
  1496. # decoder layers
  1497. all_hidden_states = () if output_hidden_states else None
  1498. all_self_attns = () if output_attentions else None
  1499. all_cross_attentions = (
  1500. () if (output_attentions and encoder_hidden_states is not None) else None
  1501. )
  1502. next_decoder_cache = () if use_cache else None
  1503. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1504. for attn_mask, mask_name in zip(
  1505. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  1506. ):
  1507. if attn_mask is not None:
  1508. if attn_mask.size()[0] != len(self.layers):
  1509. raise ValueError(
  1510. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1511. f" {attn_mask.size()[0]}."
  1512. )
  1513. for idx, decoder_layer in enumerate(self.layers):
  1514. if output_hidden_states:
  1515. all_hidden_states += (hidden_states,)
  1516. if self.training:
1517. dropout_probability = torch.rand([])
  1518. if dropout_probability < self.layerdrop:
  1519. continue
  1520. past_key_value = (
  1521. past_key_values[idx] if past_key_values is not None else None
  1522. )
  1523. if self.gradient_checkpointing and self.training:
  1524. layer_outputs = self._gradient_checkpointing_func(
  1525. decoder_layer.__call__,
  1526. hidden_states,
  1527. attention_mask,
  1528. encoder_hidden_states,
  1529. encoder_attention_mask,
  1530. head_mask[idx] if head_mask is not None else None,
  1531. (
  1532. cross_attn_head_mask[idx]
  1533. if cross_attn_head_mask is not None
  1534. else None
  1535. ),
  1536. None,
  1537. output_attentions,
  1538. use_cache,
  1539. )
  1540. else:
  1541. layer_outputs = decoder_layer(
  1542. hidden_states,
  1543. attention_mask=attention_mask,
  1544. encoder_hidden_states=encoder_hidden_states,
  1545. encoder_attention_mask=encoder_attention_mask,
  1546. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1547. cross_attn_layer_head_mask=(
  1548. cross_attn_head_mask[idx]
  1549. if cross_attn_head_mask is not None
  1550. else None
  1551. ),
  1552. past_key_value=past_key_value,
  1553. output_attentions=output_attentions,
  1554. use_cache=use_cache,
  1555. )
  1556. hidden_states = layer_outputs[0]
  1557. if self.is_export:
  1558. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  1559. else:
  1560. if use_cache:
  1561. next_decoder_cache += (
  1562. layer_outputs[3 if output_attentions else 1],
  1563. )
  1564. if output_attentions:
  1565. all_self_attns += (layer_outputs[1],)
  1566. if encoder_hidden_states is not None:
  1567. all_cross_attentions += (layer_outputs[2],)
  1568. hidden_states = self.layer_norm(hidden_states)
  1569. if output_hidden_states:
  1570. all_hidden_states += (hidden_states,)
  1571. if self.is_export:
  1572. next_cache = next_decoder_cache
  1573. else:
  1574. next_cache = next_decoder_cache if use_cache else None
  1575. if not self.is_export:
  1576. if not return_dict:
  1577. return tuple(
  1578. v
  1579. for v in [
  1580. hidden_states,
  1581. next_cache,
  1582. all_hidden_states,
  1583. all_self_attns,
  1584. all_cross_attentions,
  1585. ]
  1586. if v is not None
  1587. )
  1588. return BaseModelOutputWithPastAndCrossAttentions(
  1589. last_hidden_state=hidden_states,
  1590. past_key_values=next_cache,
  1591. hidden_states=all_hidden_states,
  1592. attentions=all_self_attns,
  1593. cross_attentions=all_cross_attentions,
  1594. )
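# Illustrative sketch (not from the original file) of the length-aware injection done in
# CustomMBartDecoder.forward above: a per-symbol count prediction is projected back to
# d_model and added (scaled by 0.5) to every decoder position. Dimensions are toy values.
def _demo_counting_context_injection():
    batch, tgt_len, d_model, vocab_size = 2, 7, 16, 50
    counting_context_weight = nn.Sequential(
        nn.Linear(vocab_size, d_model), nn.ReLU(),
        nn.Linear(d_model, d_model), nn.ReLU(),
        nn.Linear(d_model, d_model),
    )
    hidden_states = torch.randn(batch, tgt_len, d_model)       # token embeddings + positions
    count_pred = torch.rand(batch, vocab_size)                  # stand-in for SeqCountingDecoder output
    context = counting_context_weight(count_pred)               # [batch, d_model]
    hidden_states = hidden_states + 0.5 * context.unsqueeze(1)  # broadcast over tgt_len
    assert hidden_states.shape == (batch, tgt_len, d_model)
    return hidden_states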
  1595. class SelfAttentionBlock(nn.Module):
  1596. """
  1597. A self-attention block that implements multi-head self-attention
  1598. followed by a feed-forward network, typically used in transformer architectures.
  1599. Args:
  1600. embed_size (int): The size of the embedding vector.
  1601. num_heads (int): The number of attention heads.
  1602. is_export (bool): Flag indicating whether to configure the layer for export.
  1603. """
  1604. def __init__(self, embed_size, num_heads, is_export):
  1605. super(SelfAttentionBlock, self).__init__()
  1606. self.self_attention = MyMultiheadAttention(
  1607. embed_dim=embed_size, num_heads=num_heads, is_export=is_export
  1608. )
  1609. self.norm = nn.LayerNorm(embed_size)
  1610. def forward(self, x):
  1611. attn_output, _ = self.self_attention(x, x, x)
  1612. x = self.norm(attn_output + x)
  1613. return x
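# Illustrative sketch (not part of the original module): SelfAttentionBlock applies
# self-attention with a residual connection followed by LayerNorm, preserving the input
# shape. Sizes are example values.
def _demo_self_attention_block():
    block = SelfAttentionBlock(embed_size=32, num_heads=4, is_export=False)
    x = torch.randn(5, 3, 32)
    y = block(x)              # attn(x, x, x) -> residual add -> LayerNorm
    assert y.shape == x.shape
    return y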
  1614. class SeqCountingDecoder(nn.Module):
  1615. """
  1616. A custom sequence counting decoder that incorporates multi-head attention layers
1617. and feed-forward networks to process sequences, e.g. to predict per-symbol counts for LaTeX code (a shape-flow sketch follows the class definition).
  1618. Args:
  1619. in_features (int): The number of input features.
  1620. out_features (int): The number of output features.
  1621. num_heads (int): The number of attention heads. Defaults to 8.
  1622. num_layers (int): The number of attention layers. Defaults to 4.
  1623. is_export (bool): Flag indicating whether to configure the layer for export.
  1624. """
  1625. def __init__(
  1626. self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
  1627. ):
  1628. super(SeqCountingDecoder, self).__init__()
  1629. self.attention_blocks = nn.ModuleList(
  1630. [
  1631. SelfAttentionBlock(
  1632. embed_size=in_features, num_heads=num_heads, is_export=is_export
  1633. )
  1634. for i in range(num_layers)
  1635. ]
  1636. )
  1637. self.fc1 = nn.Linear(in_features, in_features // 2)
  1638. self.relu = nn.ReLU()
  1639. self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
  1640. self.fc2 = nn.Linear(in_features // 2, out_features)
  1641. def forward(self, x):
  1642. for block in self.attention_blocks:
  1643. x = block(x)
  1644. x = self.fc1(x)
  1645. x = self.relu(x)
1646. x = x.permute(0, 2, 1)  # [B, L, C] -> [B, C, L] so the 1D pooling runs over the sequence dimension
  1647. x = self.global_avg_pool(x)
  1648. x = x.squeeze(-1)
  1649. x = self.fc2(x)
  1650. return x
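# Illustrative shape-flow sketch (not part of the original module): SeqCountingDecoder maps
# encoder features [batch, seq_len, in_features] to a count vector [batch, out_features],
# pooling over the sequence dimension. The dimensions below are small example values.
def _demo_seq_counting_decoder():
    decoder = SeqCountingDecoder(in_features=64, out_features=100, num_heads=8, num_layers=2)
    feats = torch.randn(2, 30, 64)      # [batch, seq_len, in_features]
    counts = decoder(feats)             # attention blocks -> fc1 -> pool over seq -> fc2
    assert counts.shape == (2, 100)
    return counts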
  1651. class CustomMBartForCausalLM(MBartForCausalLM):
  1652. """
  1653. Custom MBart model for causal language modeling with a custom decoder.
  1654. This class extends the MBartForCausalLM by replacing its decoder with a
  1655. custom decoder, allowing for additional flexibility and features in the
  1656. decoding process.
  1657. Args:
  1658. config: The configuration object containing model parameters.
  1659. length_aware (bool): A flag to enable or configure length-aware mechanisms.
  1660. """
  1661. def __init__(self, config, length_aware=True):
  1662. super().__init__(config)
  1663. self.model.decoder = CustomMBartDecoder(config)
  1664. self.counting_decoder = SeqCountingDecoder(
  1665. config.d_model, config.vocab_size, is_export=config.is_export
  1666. )
  1667. self.length_aware = length_aware
  1668. def forward(
  1669. self,
  1670. input_ids=None,
  1671. attention_mask=None,
  1672. encoder_hidden_states=None,
  1673. encoder_attention_mask=None,
  1674. head_mask=None,
  1675. cross_attn_head_mask=None,
  1676. past_key_values=None,
  1677. inputs_embeds=None,
  1678. labels=None,
  1679. use_cache=None,
  1680. output_attentions=None,
  1681. output_hidden_states=None,
  1682. return_dict=None,
  1683. count_gt=None,
  1684. ):
  1685. output_attentions = (
  1686. output_attentions
  1687. if output_attentions is not None
  1688. else self.config.output_attentions
  1689. )
  1690. output_hidden_states = (
  1691. output_hidden_states
  1692. if output_hidden_states is not None
  1693. else self.config.output_hidden_states
  1694. )
  1695. return_dict = (
  1696. return_dict if return_dict is not None else self.config.use_return_dict
  1697. )
  1698. if self.length_aware:
  1699. count_pred = self.counting_decoder(encoder_hidden_states)
  1700. else:
  1701. count_pred = None
  1702. outputs = self.model.decoder(
  1703. input_ids=input_ids,
  1704. attention_mask=attention_mask,
  1705. count_pred=count_pred,
  1706. encoder_hidden_states=encoder_hidden_states,
  1707. encoder_attention_mask=encoder_attention_mask,
  1708. head_mask=head_mask,
  1709. cross_attn_head_mask=cross_attn_head_mask,
  1710. past_key_values=past_key_values,
  1711. inputs_embeds=inputs_embeds,
  1712. use_cache=use_cache,
  1713. output_attentions=output_attentions,
  1714. output_hidden_states=output_hidden_states,
  1715. return_dict=return_dict,
  1716. )
  1717. logits = self.lm_head(outputs[0])
  1718. return CausalLMOutputWithCrossAttentionsAndCounting(
  1719. logits=logits,
  1720. counting=count_pred,
  1721. past_key_values=outputs.past_key_values,
  1722. hidden_states=outputs.hidden_states,
  1723. attentions=outputs.attentions,
  1724. cross_attentions=outputs.cross_attentions,
  1725. )
  1726. class UniMERNetHead(nn.Module):
  1727. """Implementation of UniMERNetHead decoder.
  1728. Args:
  1729. max_new_tokens (int): Maximum number of new tokens to generate.
  1730. decoder_start_token_id (int): ID of the token that starts the decoding.
  1731. temperature (float): Sampling temperature for generation.
  1732. do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
  1733. top_p (float): Top-p (nucleus) sampling parameter.
  1734. in_channels (int): Number of input channels/features.
  1735. encoder_hidden_size (int): Hidden size of the encoder.
  1736. decoder_hidden_size (int): Hidden size of the decoder.
  1737. decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
  1738. decoder_layers (int): Number of layers in the decoder.
  1739. is_export (bool): Flag indicating if the model is being prepared for export.
  1740. length_aware (bool): Flag to enable length-aware mechanisms.
  1741. """
  1742. def __init__(
  1743. self,
  1744. max_new_tokens=1536,
  1745. decoder_start_token_id=0,
  1746. temperature=0.2,
  1747. do_sample=False,
  1748. top_p=0.95,
  1749. in_channels=1024,
  1750. encoder_hidden_size=1024,
  1751. decoder_hidden_size=1024,
  1752. decoder_ffn_dim=4096,
  1753. decoder_layers=8,
  1754. is_export=False,
  1755. length_aware=True,
  1756. ):
  1757. super().__init__()
  1758. mbart_config_dict = {
  1759. "activation_dropout": 0.0,
  1760. "activation_function": "gelu",
  1761. "add_cross_attention": True,
  1762. "add_final_layer_norm": True,
  1763. "attention_dropout": 0.0,
  1764. "bos_token_id": 0,
  1765. "classifier_dropout": 0.0,
  1766. "d_model": decoder_hidden_size,
  1767. "decoder_attention_heads": 16,
  1768. "decoder_ffn_dim": decoder_ffn_dim,
  1769. "decoder_layerdrop": 0.0,
  1770. "decoder_layers": decoder_layers,
  1771. "dropout": 0.1,
  1772. "encoder_attention_heads": 16,
  1773. "encoder_ffn_dim": 4096,
  1774. "encoder_layerdrop": 0.0,
  1775. "encoder_layers": 12,
  1776. "eos_token_id": 2,
  1777. "forced_eos_token_id": 2,
  1778. "init_std": 0.02,
  1779. "is_decoder": True,
  1780. "is_encoder_decoder": False,
  1781. "output_hidden_states": False,
  1782. "max_position_embeddings": max_new_tokens,
  1783. "model_type": "mbart",
  1784. "num_hidden_layers": 12,
  1785. "pad_token_id": 1,
  1786. "scale_embedding": True,
  1787. "tie_word_embeddings": False,
  1788. "transformers_version": "4.40.0",
  1789. "use_cache": True,
  1790. "use_return_dict": True,
  1791. "vocab_size": 50000,
  1792. "_attn_implementation": "eager",
  1793. "hidden_size": decoder_hidden_size,
  1794. "is_export": is_export,
  1795. }
  1796. self.max_new_tokens = max_new_tokens
  1797. self.decoder_start_token_id = decoder_start_token_id
  1798. self.temperature = temperature
  1799. self.do_sample = do_sample
  1800. self.top_p = top_p
  1801. self.max_seq_len = max_new_tokens
  1802. self.config_decoder = MBartConfig(**mbart_config_dict)
  1803. self.encoder_hidden_size = encoder_hidden_size
  1804. self.is_export = self.config_decoder.is_export
  1805. self.decoder = CustomMBartForCausalLM(
  1806. self.config_decoder, length_aware=length_aware
  1807. )
  1808. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  1809. self.enc_to_dec_proj = nn.Linear(
  1810. self.encoder_hidden_size, self.config_decoder.hidden_size
  1811. )
  1812. generation_config = {
  1813. "max_length": 1537,
  1814. "forced_eos_token_id": 2,
  1815. }
  1816. self.eos_token_id = generation_config["forced_eos_token_id"]
  1817. self.pad_token_id = self.config_decoder.pad_token_id
  1818. self.logits_processor = LogitsProcessorList()
  1819. self.logits_processor.append(
  1820. ForcedEOSTokenLogitsProcessor(
  1821. generation_config["max_length"],
  1822. generation_config["forced_eos_token_id"],
  1823. )
  1824. )
  1825. def _get_decoder_start_token_id(
  1826. self, decoder_start_token_id=None, bos_token_id=None
  1827. ) -> int:
  1828. decoder_start_token_id = (
  1829. decoder_start_token_id
  1830. if decoder_start_token_id is not None
  1831. else self.generation_config.decoder_start_token_id
  1832. )
  1833. bos_token_id = (
  1834. bos_token_id
  1835. if bos_token_id is not None
  1836. else self.generation_config.bos_token_id
  1837. )
  1838. if decoder_start_token_id is not None:
  1839. return decoder_start_token_id
  1840. elif bos_token_id is not None:
  1841. return bos_token_id
  1842. raise ValueError(
  1843. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  1844. )
  1845. def _prepare_decoder_input_ids_for_generation(
  1846. self,
  1847. batch_size,
  1848. model_kwargs,
  1849. decoder_start_token_id=None,
  1850. bos_token_id=None,
  1851. ):
  1852. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  1853. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  1854. elif "input_ids" in model_kwargs:
  1855. decoder_input_ids = model_kwargs.pop("input_ids")
  1856. else:
  1857. decoder_input_ids = None
  1858. decoder_start_token_id = self._get_decoder_start_token_id(
  1859. decoder_start_token_id, bos_token_id
  1860. )
  1861. if isinstance(decoder_start_token_id, list):
  1862. if len(decoder_start_token_id) != batch_size:
  1863. raise ValueError(
  1864. f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
  1865. )
  1866. decoder_input_ids_start = torch.LongTensor(decoder_start_token_id)
  1867. decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
  1868. else:
  1869. decoder_input_ids_start = (
  1870. torch.ones(
  1871. (batch_size, 1),
  1872. dtype=torch.int64,
  1873. )
  1874. * decoder_start_token_id
  1875. )
  1876. if decoder_input_ids is None:
  1877. decoder_input_ids = decoder_input_ids_start
  1878. elif (
  1879. self.config.model_type == "vision-encoder-decoder"
  1880. and "donut" in self.name_or_path.lower()
  1881. ):
  1882. pass
  1883. elif self.config.model_type in ["whisper"]:
  1884. pass
  1885. elif (
  1886. isinstance(decoder_start_token_id, int)
  1887. and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
  1888. ) or (
  1889. isinstance(decoder_start_token_id, torch.Tensor)
  1890. and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
  1891. ):
  1892. decoder_input_ids = torch.concat(
  1893. [decoder_input_ids_start, decoder_input_ids], dim=-1
  1894. )
  1895. if "decoder_attention_mask" in model_kwargs:
  1896. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  1897. decoder_attention_mask = torch.cat(
  1898. (
  1899. torch.ones_like(decoder_attention_mask)[:, :1],
  1900. decoder_attention_mask,
  1901. ),
  1902. dim=-1,
  1903. )
  1904. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  1905. return decoder_input_ids, model_kwargs
  1906. def prepare_inputs_for_generation_mbart(
  1907. self,
  1908. input_ids,
  1909. past_key_values=None,
  1910. attention_mask=None,
  1911. use_cache=None,
  1912. **kwargs,
  1913. ):
  1914. if attention_mask is None:
  1915. attention_mask = torch.ones(input_ids.shape)
  1916. if past_key_values:
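# With a populated cache, the decoder only needs the tokens it has not consumed yet;
# in the usual step-by-step greedy loop that is just the single newest token.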
  1917. past_length = past_key_values[0][0].shape[2]
  1918. if input_ids.shape[1] > past_length:
  1919. remove_prefix_length = past_length
  1920. else:
  1921. remove_prefix_length = input_ids.shape[1] - 1
  1922. input_ids = input_ids[:, remove_prefix_length:]
  1923. return {
  1924. "input_ids": input_ids,
  1925. "attention_mask": attention_mask,
  1926. "past_key_values": past_key_values,
  1927. "use_cache": use_cache,
  1928. }
  1929. def prepare_inputs_for_generation(
  1930. self,
  1931. input_ids,
  1932. past_key_values=None,
  1933. attention_mask=None,
  1934. use_cache=None,
  1935. encoder_outputs=None,
  1936. **kwargs,
  1937. ):
  1938. decoder_inputs = self.prepare_inputs_for_generation_mbart(
  1939. input_ids, past_key_values=past_key_values
  1940. )
  1941. decoder_attention_mask = (
  1942. decoder_inputs["attention_mask"]
  1943. if "attention_mask" in decoder_inputs
  1944. else None
  1945. )
  1946. input_dict = {
  1947. "attention_mask": attention_mask,
  1948. "decoder_attention_mask": decoder_attention_mask,
  1949. "decoder_input_ids": decoder_inputs["input_ids"],
  1950. "encoder_outputs": encoder_outputs,
  1951. "past_key_values": decoder_inputs["past_key_values"],
  1952. "use_cache": use_cache,
  1953. }
  1954. return input_dict
  1955. def prepare_inputs_for_generation_export(
  1956. self,
  1957. past_key_values=None,
  1958. attention_mask=None,
  1959. use_cache=None,
  1960. encoder_outputs=None,
  1961. **kwargs,
  1962. ):
  1963. input_dict = {
  1964. "decoder_attention_mask": None,
  1965. "use_cache": use_cache,
  1966. }
  1967. return input_dict
  1968. def _extract_past_from_model_output(
  1969. self, outputs: ModelOutput, standardize_cache_format: bool = False
  1970. ):
  1971. past_key_values = None
  1972. if "past_key_values" in outputs:
  1973. past_key_values = outputs.past_key_values
  1974. elif "mems" in outputs:
  1975. past_key_values = outputs.mems
  1976. elif "past_buckets_states" in outputs:
  1977. past_key_values = outputs.past_buckets_states
  1978. return past_key_values
  1979. def _update_model_kwargs_for_generation(
  1980. self,
  1981. outputs: ModelOutput,
  1982. model_kwargs: Dict[str, Any],
  1983. is_encoder_decoder: bool = False,
  1984. standardize_cache_format: bool = False,
  1985. ) -> Dict[str, Any]:
  1986. model_kwargs["past_key_values"] = self._extract_past_from_model_output(
  1987. outputs, standardize_cache_format=standardize_cache_format
  1988. )
  1989. if getattr(outputs, "state", None) is not None:
  1990. model_kwargs["state"] = outputs.state
  1991. if "token_type_ids" in model_kwargs:
  1992. token_type_ids = model_kwargs["token_type_ids"]
  1993. model_kwargs["token_type_ids"] = torch.concat(
  1994. [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
  1995. )
  1996. if not is_encoder_decoder:
  1997. if "attention_mask" in model_kwargs:
  1998. attention_mask = model_kwargs["attention_mask"]
  1999. model_kwargs["attention_mask"] = torch.concat(
  2000. [
  2001. attention_mask,
  2002. attention_mask.new_ones((attention_mask.shape[0], 1)),
  2003. ],
  2004. dim=-1,
  2005. )
  2006. else:
  2007. if "decoder_attention_mask" in model_kwargs:
  2008. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  2009. model_kwargs["decoder_attention_mask"] = torch.concat(
  2010. [
  2011. decoder_attention_mask,
  2012. decoder_attention_mask.new_ones(
  2013. (decoder_attention_mask.shape[0], 1)
  2014. ),
  2015. ],
  2016. dim=-1,
  2017. )
  2018. if (
  2019. "cache_position" in model_kwargs
  2020. and model_kwargs["cache_position"] is not None
  2021. ):
  2022. model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
  2023. return model_kwargs
  2024. def stopping_criteria(self, input_ids):
  2025. if self.is_export:
2026. return input_ids[:, -1] == torch.tensor([self.eos_token_id])
2027. is_done = torch.isin(input_ids[:, -1], torch.tensor([self.eos_token_id]))
  2028. return is_done
  2029. def generate_single_iter(
  2030. self,
  2031. decoder_input_ids=None,
  2032. decoder_attention_mask=None,
  2033. encoder_outputs=None,
  2034. past_key_values=None,
  2035. decoder_inputs_embeds=None,
  2036. labels=None,
  2037. use_cache=None,
  2038. output_attentions=None,
  2039. output_hidden_states=None,
  2040. return_dict=None,
  2041. **kwargs,
  2042. ):
  2043. encoder_hidden_states = encoder_outputs[0]
  2044. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  2045. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  2046. kwargs_decoder = {}
  2047. decoder_outputs = self.decoder(
  2048. input_ids=decoder_input_ids,
  2049. attention_mask=decoder_attention_mask,
  2050. encoder_hidden_states=encoder_hidden_states,
  2051. encoder_attention_mask=None,
  2052. inputs_embeds=None,
  2053. output_attentions=False,
  2054. output_hidden_states=output_hidden_states,
  2055. use_cache=use_cache,
  2056. past_key_values=past_key_values,
  2057. return_dict=return_dict,
  2058. **kwargs_decoder,
  2059. )
  2060. return Seq2SeqLMOutput(
  2061. loss=None,
  2062. logits=decoder_outputs.logits,
  2063. past_key_values=decoder_outputs.past_key_values,
  2064. decoder_hidden_states=decoder_outputs.hidden_states,
  2065. decoder_attentions=decoder_outputs.attentions,
  2066. cross_attentions=decoder_outputs.cross_attentions,
  2067. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  2068. encoder_hidden_states=encoder_outputs.hidden_states,
  2069. encoder_attentions=encoder_outputs.attentions,
  2070. )
  2071. @torch.no_grad()
  2072. def generate(
  2073. self,
  2074. model_kwargs,
  2075. ):
  2076. """
  2077. Generate sequences using the UniMERNetHead for inference tasks.
  2078. Args:
  2079. model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
  2080. - encoder_outputs: Outputs from the encoder.
  2081. - use_cache: Boolean flag to indicate if caching should be used.
  2082. - output_attentions: Boolean flag for outputting attention scores.
  2083. - output_hidden_states: Boolean flag for outputting hidden states.
  2084. Returns:
  2085. A tensor containing the generated sequences.
  2086. """
  2087. batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
  2088. generation_config = {
  2089. "decoder_start_token_id": 0,
  2090. "bos_token_id": 0,
  2091. }
  2092. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2093. batch_size=batch_size,
  2094. model_kwargs=model_kwargs,
  2095. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2096. bos_token_id=generation_config["bos_token_id"],
  2097. )
2098. model_kwargs["use_cache"] = True
  2099. batch_size, cur_len = input_ids.shape
  2100. if "inputs_embeds" in model_kwargs:
  2101. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2102. model_kwargs["cache_position"] = torch.arange(cur_len)
  2103. pad_token_id = self.pad_token_id
  2104. eos_token_id = [self.eos_token_id]
  2105. eos_token = self.eos_token_id
  2106. unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
  2107. for idx in range(self.max_seq_len):
  2108. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  2109. outputs = self.generate_single_iter(
  2110. **model_inputs,
  2111. return_dict=True,
  2112. output_attentions=False,
  2113. output_hidden_states=False,
  2114. )
  2115. next_token_logits = outputs.logits[:, -1, :]
  2116. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2117. next_tokens = torch.argmax(next_tokens_scores, dim=-1)
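# Greedy selection; rows whose sequences are already finished are overwritten with
# pad_token_id below so their length keeps matching the unfinished rows.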
  2118. if eos_token_id is not None:
  2119. if pad_token_id is None:
  2120. raise ValueError(
  2121. "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
  2122. )
  2123. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2124. 1 - unfinished_sequences
  2125. )
  2126. input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
  2127. model_kwargs = self._update_model_kwargs_for_generation(
  2128. outputs,
  2129. model_kwargs,
  2130. is_encoder_decoder=self.config_decoder.is_encoder_decoder,
  2131. )
  2132. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2133. input_ids
  2134. ).to(torch.int64)
  2135. if (
  2136. eos_token is not None
  2137. and (
  2138. torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
  2139. >= 1
  2140. ).all()
  2141. ):
  2142. break
  2143. return input_ids
  2144. @torch.no_grad()
  2145. def generate_export(
  2146. self,
  2147. encoder_outputs,
  2148. model_kwargs,
  2149. ):
  2150. batch_size = encoder_outputs["last_hidden_state"].shape[0]
  2151. generation_config = {
  2152. "decoder_start_token_id": 0,
  2153. "bos_token_id": 0,
  2154. }
  2155. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2156. batch_size=batch_size,
  2157. model_kwargs=model_kwargs,
  2158. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2159. bos_token_id=generation_config["bos_token_id"],
  2160. )
  2161. input_ids = input_ids.reshape([-1, 1])
  2162. decoder_input_ids = input_ids
2163. model_kwargs["use_cache"] = True
  2164. batch_size, cur_len = input_ids.shape
  2165. if "inputs_embeds" in model_kwargs:
  2166. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2167. cache_position = torch.arange(cur_len)
  2168. pad_token_id = self.pad_token_id
  2169. eos_token_id = [self.eos_token_id]
  2170. eos_token = self.eos_token_id
  2171. unfinished_sequences = torch.ones([batch_size], dtype=torch.int64)
  2172. i_idx = torch.full([], 0)
  2173. past_key_values = []
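# The export path pre-builds one empty (self-attn key, self-attn value, cross-attn key,
# cross-attn value) entry per decoder layer, shaped [batch, num_heads, 0, head_dim] for the
# default config (8 layers, 16 heads, head_dim 64), so the traced graph sees a fixed cache
# structure on the first step.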
  2174. for i in range(8):
  2175. init_arr = torch.zeros([batch_size, 16, 0, 64])
  2176. cache = (init_arr, init_arr, init_arr, init_arr)
  2177. past_key_values.append(cache)
  2178. idx = 0
2179. while i_idx < torch.tensor(self.max_seq_len):  # torch.tensor keeps a scalar bound; torch.Tensor(n) would allocate an uninitialized length-n tensor
  2180. model_inputs = self.prepare_inputs_for_generation_export(
  2181. past_key_values=past_key_values, **model_kwargs
  2182. )
2183. # The export-path inputs never carry a decoder mask (prepare_inputs_for_generation_export returns None),
2184. decoder_attention_mask = torch.ones(input_ids.shape)  # so build an all-ones mask over the tokens generated so far.
  2185. outputs = self.generate_single_iter(
  2186. decoder_input_ids=decoder_input_ids,
  2187. decoder_attention_mask=decoder_attention_mask,
  2188. encoder_outputs=encoder_outputs,
  2189. past_key_values=past_key_values,
  2190. return_dict=True,
  2191. output_attentions=False,
  2192. output_hidden_states=False,
  2193. )
  2194. next_token_logits = outputs.logits[:, -1, :]
  2195. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2196. next_tokens = torch.argmax(next_tokens_scores, dim=-1)
  2197. if eos_token_id is not None:
  2198. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2199. 1 - unfinished_sequences
  2200. )
  2201. input_ids = torch.concat([input_ids, next_tokens.unsqueeze(1)], dim=-1)
  2202. past_length = past_key_values[0][0].shape[2]
  2203. decoder_input_ids = next_tokens.unsqueeze(1)
  2204. past_key_values = outputs.past_key_values
  2205. cache_position = cache_position[-1:] + 1
  2206. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2207. input_ids
  2208. ).to(torch.int64)
  2209. if (
  2210. eos_token is not None
  2211. and (
  2212. torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
  2213. >= 1
  2214. ).all()
  2215. ):
  2216. break
  2217. i_idx += 1
  2218. return input_ids
2219. def forward_train(
  2220. self,
  2221. encoder_outputs,
  2222. decoder_input_ids,
  2223. decoder_attention_mask,
  2224. past_key_values=None,
  2225. decoder_inputs_embeds=None,
  2226. labels=None,
  2227. use_cache=None,
  2228. output_attentions=None,
  2229. output_hidden_states=None,
  2230. return_dict=None,
  2231. **kwargs,
  2232. ):
  2233. """
2234. Training forward pass for the UniMERNetHead.
  2235. Args:
  2236. encoder_outputs: Outputs from the encoder, used as input to the decoder.
  2237. decoder_input_ids: Input IDs for the decoder.
  2238. decoder_attention_mask: Attention mask for the decoder inputs.
  2239. past_key_values: Cached key/values for faster decoding.
  2240. decoder_inputs_embeds: Optional embeddings for the decoder inputs.
  2241. labels: Target labels for calculating loss.
  2242. use_cache: Whether to use cache during decoding.
  2243. output_attentions: Whether to return attention scores.
  2244. output_hidden_states: Whether to return hidden states.
  2245. return_dict: Whether to return a dictionary of outputs.
  2246. **kwargs: Additional keyword arguments.
  2247. Returns:
  2248. logits: The raw, unnormalized predictions from the model.
  2249. count_pred: Optional prediction related to sequence length or other counts.
  2250. masked_labels: The labels used during training, possibly masked.
  2251. """
2252. labels = decoder_input_ids.clone()
2253. labels = labels.masked_fill(labels == self.pad_token_id, -100)
  2254. input_decoder_input_ids = decoder_input_ids[:, :-1]
  2255. input_decoder_attention_mask = decoder_attention_mask[:, :-1]
  2256. encoder_hidden_states = encoder_outputs[0]
  2257. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  2258. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  2259. kwargs_decoder = {}
  2260. decoder_outputs = self.decoder(
  2261. input_ids=input_decoder_input_ids,
  2262. attention_mask=input_decoder_attention_mask,
  2263. encoder_hidden_states=encoder_hidden_states,
  2264. encoder_attention_mask=None,
  2265. inputs_embeds=None,
  2266. output_attentions=False,
  2267. output_hidden_states=output_hidden_states,
  2268. use_cache=use_cache,
  2269. past_key_values=past_key_values,
  2270. return_dict=return_dict,
  2271. **kwargs_decoder,
  2272. )
  2273. logits = decoder_outputs.logits
  2274. count_pred = decoder_outputs.counting
  2275. return logits, count_pred, labels
  2276. def forward(self, inputs, targets=None):
  2277. """
  2278. Forward pass for the UniMERNetHead, handling both training and inference.
  2279. Args:
  2280. inputs: The input data, which can vary based on training or inference.
  2281. targets: The target labels, used only during training.
2282. Returns:
2283. During inference: the predicted LaTeX token ids.
2284. During training: logits, predicted counts, and masked labels.
  2285. """
2286. self.is_export = not self.training
  2287. if not self.training:
  2288. encoder_outputs = inputs
  2289. if self.is_export:
  2290. model_kwargs = {
  2291. "output_attentions": False,
  2292. "output_hidden_states": False,
  2293. "use_cache": True,
  2294. }
  2295. word_pred = self.generate_export(encoder_outputs, model_kwargs)
  2296. else:
  2297. model_kwargs = {
  2298. "output_attentions": False,
  2299. "output_hidden_states": False,
  2300. "use_cache": True,
  2301. "encoder_outputs": encoder_outputs,
  2302. }
  2303. word_pred = self.generate(model_kwargs)
  2304. return word_pred
  2305. encoder_outputs, tgt_seq, mask = inputs
2306. logits, count_pred, masked_labels = self.forward_train(
  2307. encoder_outputs, tgt_seq, mask
  2308. )
  2309. return logits, count_pred, masked_labels
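# Illustrative end-to-end inference sketch (not part of the original module). It builds a
# deliberately tiny head and calls the non-export greedy loop `generate` directly, because
# `generate_export` pre-allocates caches sized for the default 8-layer, 1024-dim decoder.
# All sizes are toy values, and the vendored MBart components imported by this module are
# assumed to be constructible on CPU with these sizes.
def _demo_unimernet_head_greedy_decode():
    head = UniMERNetHead(
        max_new_tokens=8,       # keeps the greedy loop short
        decoder_hidden_size=64,
        decoder_ffn_dim=128,
        decoder_layers=1,
        encoder_hidden_size=64,
        in_channels=64,
        is_export=False,
    )
    head.eval()
    fake_encoder_out = BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=torch.randn(1, 20, 64)  # [batch, src_len, encoder_hidden_size]
    )
    model_kwargs = {
        "output_attentions": False,
        "output_hidden_states": False,
        "use_cache": True,
        "encoder_outputs": fake_encoder_out,
    }
    token_ids = head.generate(model_kwargs)  # [batch, <= max_new_tokens + 1] token ids
    return token_ids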