# rec_unimernet_head.py

import copy
import math
import re
import numpy as np
import inspect
import warnings
from collections import OrderedDict
from typing import Optional, Tuple, Union, List, Dict, Any
from dataclasses import dataclass, fields, is_dataclass

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from mineru.utils.config_reader import get_device
class ModelOutput(OrderedDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __post_init__(self):
        class_fields = fields(self)
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(
                f"{self.__class__.__name__} should not have more than one required field."
            )
        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(
            getattr(self, field.name) is None for field in class_fields[1:]
        )
        if other_fields_are_none:
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            self[class_fields[0].name] = first_field
                        else:
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
        )

    def setdefault(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
        )

    def pop(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
        )

    def update(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``update`` on a {self.__class__.__name__} instance."
        )

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def __reduce__(self):
        if not is_dataclass(self):
            return super().__reduce__()
        callable, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return callable, args, *remaining

    def to_tuple(self):
        return tuple(self[k] for k in self.keys())
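
# Usage note (editorial sketch, not part of the original module): a ModelOutput
# subclass can be read either by key or by position, e.g.
#   out = BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=h, past_key_values=cache)
#   out.last_hidden_state is out["last_hidden_state"] is out[0]   # same tensor
# Keys whose value is None are never stored, so integer indices only cover populated fields.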
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    last_hidden_state = None
    past_key_values = None
    hidden_states = None
    attentions = None
    cross_attentions = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


@dataclass
class Seq2SeqLMOutput(ModelOutput):
    loss = None
    logits = None
    past_key_values = None
    decoder_hidden_states = None
    decoder_attentions = None
    cross_attentions = None
    encoder_last_hidden_state = None
    encoder_hidden_states = None
    encoder_attentions = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class MBartConfig(object):
    model_type = "mbart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        output_hidden_states=False,
        use_return_dict=True,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        _attn_implementation="eager",
        hidden_size=1024,
        use_parallel=False,
        parallel_step=2,
        is_export=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.output_hidden_states = output_hidden_states
        # `output_attentions` is read by MBartDecoder.forward and MBartForCausalLM.forward,
        # so expose it here as well (configurable via **kwargs, defaults to False).
        self.output_attentions = kwargs.get("output_attentions", False)
        self.use_return_dict = use_return_dict
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = (
            scale_embedding  # scale factor will be sqrt(d_model) if True
        )
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.is_encoder_decoder = is_encoder_decoder
        self.forced_eos_token_id = forced_eos_token_id
        self._attn_implementation = _attn_implementation
        self.use_parallel = use_parallel
        self.parallel_step = parallel_step
        self.is_export = is_export
        super().__init__()
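
# Example (illustrative sketch): a decoder-side configuration of the kind the UniMERNet
# head builds further down in this file; the values here are placeholders, not the ones
# used by the shipped model.
#   cfg = MBartConfig(vocab_size=50000, d_model=1024, decoder_layers=8,
#                     decoder_attention_heads=16, is_export=False)
#   lm = MBartForCausalLM(cfg)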
@dataclass
class AttentionMaskConverter:
    """
    A utility class for converting attention masks used in transformer models.
    This class handles the conversion of attention masks based on whether the
    attention mechanism is causal (i.e., preventing information flow from future
    tokens to past tokens) and whether a sliding window approach is used.
    Attributes:
        is_causal (bool): Indicates if the attention mechanism is causal.
        sliding_window (Optional[int]): Specifies the size of the sliding window
            for local attention, if applicable.
    Args:
        is_causal (bool): Determines if the attention mask should enforce causality.
        sliding_window (Optional[int], optional): The size of the sliding window
            for local attention. Default is None.
    """

    is_causal: bool
    sliding_window: int

    def __init__(self, is_causal: bool, sliding_window=None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window
        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    @staticmethod
    def _make_causal_mask(
        input_ids_shape,
        dtype,
        past_key_values_length=0,
        sliding_window=None,
        is_export=False,
    ):
        bsz, tgt_len = input_ids_shape
        if is_export:
            mask = torch.full(
                [tgt_len, tgt_len], fill_value=torch.finfo(dtype).min, dtype=torch.float64
            )
        else:
            mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
        mask_cond = torch.arange(mask.shape[-1])
        mask = mask.masked_fill_(
            mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
        )
        return mask[None, None, :, :].expand(
            [bsz, 1, tgt_len, tgt_len + past_key_values_length]
        )

    def to_4d_export(
        self,
        attention_mask_2d,
        query_length,
        dtype,
        key_value_length,
        is_export=False,
    ):
        input_shape = (attention_mask_2d.shape[0], query_length)
        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        )
        expanded_4d_mask = expanded_attn_mask
        return expanded_4d_mask

    def to_4d(
        self,
        attention_mask_2d,
        query_length,
        dtype,
        key_value_length,
        is_export=False,
    ):
        input_shape = (attention_mask_2d.shape[0], query_length)
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )
            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
                is_export=is_export,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is currently only implemented for causal masking"
            )
        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        )
        if causal_4d_mask is not None:
            if is_export:
                expanded_attn_mask = causal_4d_mask
                return expanded_attn_mask
            else:
                expanded_attn_mask = causal_4d_mask.masked_fill_(
                    expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
                )
        expanded_4d_mask = expanded_attn_mask
        return expanded_4d_mask
    @staticmethod
    def _expand_mask(mask, dtype, tgt_len=None):
        # Defined as a staticmethod so the module-level _prepare_4d_attention_mask helper
        # below can call it as AttentionMaskConverter._expand_mask(...) without an instance.
        bsz, src_len = mask.shape
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = (
            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
        )
        inverted_mask = 1.0 - expanded_mask
        return inverted_mask.masked_fill_(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )
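
# Shape sketch (assumed values, for illustration only): for a batch of 2 all-ones masks
# of length 4 in fp32,
#   conv = AttentionMaskConverter(is_causal=True)
#   m = conv.to_4d(torch.ones(2, 4, dtype=torch.int64), 4, torch.float32, key_value_length=4)
#   m.shape  # -> [2, 1, 4, 4]; allowed positions hold 0.0, future positions hold finfo(float32).min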
def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


def _prepare_4d_causal_attention_mask_export(
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window=None,
    is_export=False,
):
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    key_value_length = input_shape[-1] + past_key_values_length
    shape = attention_mask.shape
    len_shape = len(shape)
    attention_mask = attn_mask_converter.to_4d_export(
        attention_mask,
        input_shape[-1],
        key_value_length=key_value_length,
        dtype=inputs_embeds.dtype,
        is_export=is_export,
    )
    return attention_mask


def _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window=None,
    is_export=False,
):
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    key_value_length = input_shape[-1] + past_key_values_length
    shape = attention_mask.shape
    len_shape = len(shape)
    if (attention_mask is not None) and (len_shape == 2):
        attention_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
            is_export=is_export,
        )
        return attention_mask
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill_(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
        )
    return attention_mask
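
# Note on mask semantics (editorial comment): the helpers above return *additive* masks,
# i.e. 0.0 where attention is allowed and the dtype's most negative value where it is
# forbidden; MBartAttention below simply adds them to the raw attention scores before
# the softmax.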
class MBartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings, embedding_dim):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
        self.device = torch.device(get_device())

    def forward(self, input_ids, past_key_values_length=0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
        ).expand([bsz, -1]).to(self.device)
        return nn.Embedding.forward(self, positions + self.offset)
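
# Editorial note: as in fairseq/HF mBART, position ids are shifted by offset=2 because the
# first two embedding slots are reserved (a fairseq convention tied to padding_idx), so the
# table is allocated with num_embeddings + 2 rows and position p is looked up at row p + 2.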
class MBartPreTrainedModel(nn.Module):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__()
        self.config = config

    def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)

    def post_init(self):
        self.apply(self._initialize_weights)

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, val=0.0)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                torch.nn.init.constant_(module.weight[module.padding_idx], val=0.0)

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs
class MBartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim ** -0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim]).permute(
            0, 2, 1, 3
        )

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.shape
        query_states = self.q_proj(hidden_states) * self.scaling
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.concat([past_key_value[0], key_states], dim=2)
            value_states = torch.concat([past_key_value[1], value_states], dim=2)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        if self.is_decoder:
            past_key_value = (key_states, value_states)
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
        key_states = key_states.reshape(proj_shape)
        value_states = value_states.reshape(proj_shape)
        src_len = key_states.shape[1]
        attn_weights = torch.bmm(query_states, key_states.permute([0, 2, 1]))
        if attention_mask is not None:
            attn_weights = (
                attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
                + attention_mask
            )
            attn_weights = attn_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len]
            )
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        if layer_head_mask is not None:
            if tuple(layer_head_mask.shape) != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
                    f" {layer_head_mask.shape}"
                )
            attn_weights = layer_head_mask.reshape(
                [1, -1, 1, 1]
            ) * attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
            attn_weights = attn_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len]
            )
        if output_attentions:
            attn_weights_reshaped = attn_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len]
            )
            attn_weights = attn_weights_reshaped.reshape(
                [bsz * self.num_heads, tgt_len, src_len]
            )
        else:
            attn_weights_reshaped = None
        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.bmm(attn_probs, value_states)
        attn_output = attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
        attn_output = attn_output.permute([0, 2, 1, 3])
        attn_output = attn_output.reshape([bsz, tgt_len, self.embed_dim])
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights_reshaped, past_key_value


MBART_ATTENTION_CLASSES = {
    "eager": MBartAttention,
}
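
# Shape walk-through (editorial sketch, assumed sizes): with batch b, target length t,
# source length s, d_model e and h heads (head_dim = e // h), MBartAttention maps
#   hidden_states [b, t, e] -> q/k/v [b*h, t_or_s, e//h] -> scores [b*h, t, s]
#   -> softmax(+additive mask) -> context [b*h, t, e//h] -> out_proj -> [b, t, e].
# When is_decoder=True, the returned past_key_value caches k/v as [b, h, s, e//h].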
class MBartDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.is_export = config.is_export
        self.dropout = config.dropout
        self.activation_fn = F.gelu
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.device = torch.device(get_device())

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        self_attn_past_key_value = None
        if past_key_value is not None:
            self_attn_past_key_value = tuple(
                t.to(self.device) if isinstance(t, torch.Tensor) else t
                for t in past_key_value[:2]
            )
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
            cross_attn_past_key_value = (
                past_key_value[-2:] if past_key_value is not None else None
            )
            hidden_states, cross_attn_weights, cross_attn_present_key_value = (
                self.encoder_attn(
                    hidden_states=hidden_states,
                    key_value_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=cross_attn_past_key_value,
                    output_attentions=output_attentions,
                )
            )
            hidden_states = nn.functional.dropout(
                hidden_states, p=self.dropout, training=self.training
            )
            hidden_states = residual + hidden_states
            present_key_value = present_key_value + cross_attn_present_key_value
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        if self.is_export:
            outputs += (present_key_value,)
        else:
            if use_cache:
                outputs += (present_key_value,)
        return outputs
class MBartForCausalLM(MBartPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MBartDecoderWrapper(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0])
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
            )
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs,
    ):
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past
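
# Decoding sketch (editorial, not part of the module): prepare_inputs_for_generation trims
# the prompt to the tokens the cache has not seen yet, so a minimal greedy step over this
# head could look like
#   inputs = lm.prepare_inputs_for_generation(ids, past_key_values=cache)
#   out = lm(**inputs, encoder_hidden_states=img_feats, return_dict=True)
#   next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
#   ids, cache = torch.cat([ids, next_id], dim=1), out.past_key_values
# where `ids`, `cache` and `img_feats` are hypothetical names for the running token ids,
# the key/value cache and the visual encoder features.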
class myLayerNorm(nn.LayerNorm):
    """
    Custom implementation of Layer Normalization, with additional options.
    This class extends the standard LayerNorm to include optional features,
    such as drop block regularization, which might be used for improving
    model generalization.
    Args:
        num_channels (int): The number of features or channels in the input.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
        affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
        drop_block (optional): Additional regularization technique that might be applied. Default is None.
    """

    def __init__(
        self,
        num_channels,
        eps=1e-5,
        affine=True,
        drop_block=None,
    ):
        super(nn.LayerNorm, self).__init__()
        self._epsilon = eps
        self.num_channels = num_channels
        if affine:
            self.weight = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
            self.bias = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        x = F.layer_norm(
            x,
            [self.num_channels],
            weight=self.weight,
            bias=self.bias,
            eps=self._epsilon,
        )
        return x
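
# Editorial note: because __init__ calls super(nn.LayerNorm, self).__init__(), the stock
# LayerNorm constructor is bypassed and the affine parameters are created by hand; the
# forward pass is still the standard F.layer_norm over the last (channel) dimension.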
class MBartDecoder(MBartPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
    Args:
        config
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight
        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList(
            [MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        self.is_export = config.is_export

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input = input_ids
            input_shape = input.shape
            input_ids = input_ids.reshape([-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        if self._use_flash_attention_2:
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
                is_export=self.is_export,
            )
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = (
                    encoder_attention_mask if 0 in encoder_attention_mask else None
                )
            else:
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        if self.gradient_checkpointing and self.training:
            if use_cache:
                print(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )
        next_decoder_cache = () if use_cache else None
        for attn_mask, mask_name in zip(
            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
        ):
            if attn_mask is not None:
                if attn_mask.shape[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.shape[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    (
                        cross_attn_head_mask[idx]
                        if cross_attn_head_mask is not None
                        else None
                    ),
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx]
                        if cross_attn_head_mask is not None
                        else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)
        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
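
# Cache layout (editorial note): with use_cache=True the decoder returns one tuple per
# layer; each layer tuple holds (self_attn_key, self_attn_value, cross_attn_key,
# cross_attn_value), which is exactly what MBartDecoderLayer.forward slices back out via
# past_key_value[:2] and past_key_value[-2:] on the next decoding step.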
class MBartDecoderWrapper(MBartPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MBartDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
def _in_projection(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w_q: torch.Tensor,
    w_k: torch.Tensor,
    w_v: torch.Tensor,
    b_q: Optional[torch.Tensor] = None,
    b_k: Optional[torch.Tensor] = None,
    b_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
    assert w_q.shape == (
        Eq,
        Eq,
    ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (
        Eq,
        Ek,
    ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (
        Eq,
        Ev,
    ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (
        Eq,
    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (
        Eq,
    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (
        Eq,
    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
def _scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    attn = torch.bmm(q, k.permute([0, 2, 1]))
    if attn_mask is not None:
        attn += attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    output = torch.bmm(attn, v)
    return output, attn
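
# In formula terms this computes, per batch element,
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(E) + attn_mask) V
# with Q of shape [B, Nt, E], K/V of shape [B, Ns, E] and an additive (0 / -inf style) mask.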
def linear(x, w, b, is_transpose=False):
    # `is_transpose` now defaults to False so callers that pass an already-transposed
    # weight (e.g. _in_projection above) can omit the flag.
    if is_transpose:
        w = w.T
    if b is not None:
        return torch.matmul(x, w) + b
    else:
        return torch.matmul(x, w)
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
    is_export=False,
) -> List[Tensor]:
    E = q.shape[-1]
    if k is v:
        if q is k:
            # self-attention: one packed projection, then split into q, k, v
            proj = linear(q, w, b, is_transpose=True)
            if is_export:
                B, D, L = proj.shape
                proj = proj.reshape([B, D, 3, E])
                proj = (
                    proj.unsqueeze(0)
                    .permute([3, 1, 2, 0, 4])
                    .squeeze(-2)
                    .contiguous()
                )
            else:
                proj = (
                    proj.unflatten(-1, (3, E))
                    .unsqueeze(0)
                    .permute([3, 1, 2, 0, 4])
                    .squeeze(-2)
                    .contiguous()
                )
            return proj[0], proj[1], proj[2]
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        # the packed weight is stored in (out_features, in_features) layout, so each
        # chunk must be transposed before the matmul
        return (
            linear(q, w_q, b_q, is_transpose=True),
            linear(k, w_k, b_k, is_transpose=True),
            linear(v, w_v, b_v, is_transpose=True),
        )
def multi_head_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: torch.Tensor,
    in_proj_bias: Optional[torch.Tensor],
    bias_k: Optional[torch.Tensor],
    bias_v: Optional[torch.Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: torch.Tensor,
    out_proj_bias: Optional[torch.Tensor],
    training: bool = True,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[torch.Tensor] = None,
    k_proj_weight: Optional[torch.Tensor] = None,
    v_proj_weight: Optional[torch.Tensor] = None,
    static_k: Optional[torch.Tensor] = None,
    static_v: Optional[torch.Tensor] = None,
    is_export=False,
):
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    if isinstance(embed_dim, torch.Tensor):
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    q, k, v = _in_projection_packed(
        query, key, value, in_proj_weight, in_proj_bias, is_export
    )
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)
    if bias_k is not None and bias_v is not None:  # False
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.concat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.concat([v, bias_v.repeat(1, bsz, 1)])
    else:
        assert bias_k is None
        assert bias_v is None
    q = q.reshape([tgt_len, bsz * num_heads, head_dim]).permute([1, 0, 2])
    if static_k is None:  # True
        k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
    else:
        assert (
            static_k.shape[0] == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
        assert (
            static_k.shape[2] == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
        k = static_k
    if static_v is None:  # True
        # use permute (not transpose) here, mirroring the handling of q and k above
        v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
    else:
        assert (
            static_v.shape[0] == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
        assert (
            static_v.shape[2] == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
        v = static_v
    src_len = k.shape[1]
    if not training:
        dropout_p = 0.0
    attn_output, attn_output_weights = _scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p
    )
    attn_output = attn_output.permute([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
    attn_output = linear(
        attn_output, out_proj_weight, out_proj_bias, is_transpose=False
    )
    if need_weights:
        attn_output_weights = attn_output_weights.reshape(
            [bsz, num_heads, tgt_len, src_len]
        )
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
  1199. class MyMultiheadAttention(nn.Module):
  1200. """
  1201. Custom implementation of a multi-head attention layer.
  1202. Attributes:
  1203. __constants__ (list): List of constant attributes.
  1204. bias_k (Optional[paddle.Tensor]): Optional tensor for key bias.
  1205. bias_v (Optional[paddle.Tensor]): Optional tensor for value bias.
  1206. Args:
  1207. embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
  1208. num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
  1209. dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
  1210. bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
  1211. add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
  1212. add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
  1213. kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
  1214. vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
  1215. batch_first (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default is False.
  1216. device (optional): The device on which the layer's parameters should be initialized. Default is None.
  1217. dtype (optional): The data type for the parameters. Default is None.
  1218. is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
  1219. """
  1220. __constants__ = ["batch_first"]
  1221. bias_k: Optional[torch.Tensor]
  1222. bias_v: Optional[torch.Tensor]
  1223. def __init__(
  1224. self,
  1225. embed_dim,
  1226. num_heads,
  1227. dropout=0.0,
  1228. bias=True,
  1229. add_bias_kv=False,
  1230. add_zero_attn=False,
  1231. kdim=None,
  1232. vdim=None,
  1233. batch_first=False,
  1234. device=None,
  1235. dtype=None,
  1236. is_export=False,
  1237. ) -> None:
  1238. super(MyMultiheadAttention, self).__init__()
  1239. self.embed_dim = embed_dim
  1240. self.kdim = kdim if kdim is not None else embed_dim
  1241. self.vdim = vdim if vdim is not None else embed_dim
  1242. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  1243. self.num_heads = num_heads
  1244. self.dropout = dropout
  1245. self.batch_first = batch_first
  1246. self.head_dim = embed_dim // num_heads
  1247. self.is_export = is_export
  1248. assert (
  1249. self.head_dim * num_heads == self.embed_dim
  1250. ), "embed_dim must be divisible by num_heads"
if self._qkv_same_embed_dim is False:
# Separate q/k/v projection weights (kdim/vdim != embed_dim) are not implemented in this port.
pass
else:
if dtype is None:
dtype = torch.float32
self.in_proj_weight = torch.nn.Parameter(torch.empty(3 * embed_dim, embed_dim, dtype=dtype))
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
if bias:
self.in_proj_bias = torch.nn.Parameter(torch.zeros(3 * embed_dim, dtype=dtype))
else:
self.in_proj_bias = None
  1265. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  1266. if add_bias_kv:
  1267. pass
  1268. else:
  1269. self.bias_k = self.bias_v = None
  1270. self.add_zero_attn = add_zero_attn
  1271. self._reset_parameters()
  1272. def _reset_parameters(self):
  1273. if self._qkv_same_embed_dim:
  1274. torch.nn.init.xavier_normal_(self.in_proj_weight)
  1275. else:
  1276. torch.nn.init.xavier_normal_(self.q_proj_weight)
  1277. torch.nn.init.xavier_normal_(self.k_proj_weight)
  1278. torch.nn.init.xavier_normal_(self.v_proj_weight)
  1279. if self.in_proj_bias is not None:
  1280. torch.nn.init.zeros_(self.in_proj_bias)
  1281. torch.nn.init.zeros_(self.out_proj.bias)
  1282. if self.bias_k is not None:
  1283. torch.nn.init.xavier_normal_(self.bias_k)
  1284. if self.bias_v is not None:
  1285. torch.nn.init.xavier_normal_(self.bias_v)
  1286. def forward(
  1287. self,
  1288. query: torch.Tensor,
  1289. key: torch.Tensor,
  1290. value: torch.Tensor,
  1291. key_padding_mask: Optional[torch.Tensor] = None,
  1292. need_weights: bool = True,
  1293. attn_mask: Optional[torch.Tensor] = None,
  1294. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  1295. attn_output, attn_output_weights = multi_head_attention_forward(
  1296. query,
  1297. key,
  1298. value,
  1299. self.embed_dim,
  1300. self.num_heads,
  1301. self.in_proj_weight,
  1302. self.in_proj_bias,
  1303. self.bias_k,
  1304. self.bias_v,
  1305. self.add_zero_attn,
  1306. self.dropout,
  1307. self.out_proj.weight,
  1308. self.out_proj.bias,
  1309. training=self.training,
  1310. key_padding_mask=key_padding_mask,
  1311. need_weights=need_weights,
  1312. attn_mask=attn_mask,
  1313. is_export=self.is_export,
  1314. )
  1315. return attn_output, attn_output_weights
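# Minimal usage sketch for MyMultiheadAttention (illustrative only; the 256/8 sizes below
# are assumptions, not values taken from this file). With the default batch_first=False,
# inputs are laid out as (seq_len, bsz, embed_dim):
#
#   attn = MyMultiheadAttention(embed_dim=256, num_heads=8)
#   x = torch.randn(10, 2, 256)              # (seq_len, bsz, embed_dim)
#   out, avg_weights = attn(x, x, x)         # out: (10, 2, 256); avg_weights: (2, 10, 10)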
  1316. class LogitsProcessorList(list):
  1317. """
  1318. A list of logits processors that can be applied sequentially.
  1319. Methods:
  1320. __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
  1321. """
  1322. def __call__(self, input_ids, scores, **kwargs):
  1323. for processor in self:
  1324. function_args = inspect.signature(processor.__call__).parameters
  1325. if len(function_args) > 2:
  1326. if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
  1327. raise ValueError(
  1328. f"Make sure that all the required parameters: {list(function_args.keys())} for "
  1329. f"{processor.__class__} are passed to the logits processor."
  1330. )
  1331. scores = processor(input_ids, scores, **kwargs)
  1332. else:
  1333. scores = processor(input_ids, scores)
  1334. return scores
  1335. class ForcedEOSTokenLogitsProcessor(object):
  1336. """
  1337. A processor that forces the generation of an end-of-sequence (EOS) token
  1338. at a specified position in the sequence.
  1339. This is typically used in language generation tasks to ensure that the
  1340. generated sequence ends properly when it reaches a certain length.
  1341. Args:
  1342. max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
  1343. eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
  1344. """
  1345. def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
  1346. self.max_length = max_length
  1347. if isinstance(eos_token_id, int):
  1348. eos_token_id = [eos_token_id]
  1349. self.eos_token_id = eos_token_id
  1350. def __call__(self, input_ids, scores):
  1351. cur_len = input_ids.shape[-1]
  1352. scores_processed = scores
  1353. if cur_len == self.max_length - 1:
  1354. scores_processed = torch.full_like(scores, -math.inf)
  1355. scores_processed[:, self.eos_token_id] = 0
  1356. return scores_processed
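# Sketch of how the two classes above cooperate during decoding (the 1537/2 values mirror
# the generation_config set up in UniMERNetHead.__init__ below):
#
#   processors = LogitsProcessorList()
#   processors.append(ForcedEOSTokenLogitsProcessor(max_length=1537, eos_token_id=2))
#   next_token_scores = processors(input_ids, next_token_logits)
#
# Once input_ids reaches max_length - 1, every logit except the EOS token's is set to -inf,
# so the argmax in the generation loop can only emit EOS and decoding terminates.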
  1357. @dataclass
  1358. class CausalLMOutputWithCrossAttentions(ModelOutput):
  1359. loss = None
  1360. logits = None
  1361. past_key_values = None
  1362. hidden_states = None
  1363. attentions = None
  1364. cross_attentions = None
  1365. def __init__(self, *args, **kwargs):
  1366. super().__init__(*args, **kwargs)
  1367. @dataclass
  1368. class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
  1369. """
  1370. Base class for causal language model (or autoregressive) outputs.
  1371. """
  1372. logits = None
  1373. counting = None
  1374. past_key_values = None
  1375. hidden_states = None
  1376. attentions = None
  1377. cross_attentions = None
  1378. def __init__(self, *args, **kwargs):
  1379. super().__init__(*args, **kwargs)
  1380. class CustomMBartDecoder(MBartDecoder):
  1381. """
  1382. A custom MBartDecoder that includes additional processing layers.
  1383. This class extends the MBartDecoder by adding a customizable neural network
  1384. component called `counting_context_weight`, which applies a series of linear
  1385. transformations followed by ReLU activations. This can be used to modify or
  1386. enhance the decoder's behavior for specific tasks.
  1387. Args:
  1388. config: The configuration object containing model parameters.
  1389. """
  1390. def __init__(self, config):
  1391. super().__init__(config)
  1392. hidden_size = config.d_model
  1393. self.is_export = config.is_export
  1394. self.counting_context_weight = nn.Sequential(
  1395. nn.Linear(config.vocab_size, hidden_size),
  1396. nn.ReLU(),
  1397. nn.Linear(hidden_size, hidden_size),
  1398. nn.ReLU(),
  1399. nn.Linear(hidden_size, config.d_model),
  1400. )
  1401. def forward(
  1402. self,
  1403. input_ids=None,
  1404. attention_mask=None,
  1405. count_pred=None,
  1406. encoder_hidden_states=None,
  1407. encoder_attention_mask=None,
  1408. head_mask=None,
  1409. cross_attn_head_mask=None,
  1410. past_key_values=None,
  1411. inputs_embeds=None,
  1412. use_cache=None,
  1413. output_attentions=None,
  1414. output_hidden_states=None,
  1415. return_dict=None,
  1416. ):
self.is_export = not self.training
  1418. output_attentions = (
  1419. output_attentions
  1420. if output_attentions is not None
  1421. else self.config.output_attentions
  1422. )
  1423. output_hidden_states = (
  1424. output_hidden_states
  1425. if output_hidden_states is not None
  1426. else self.config.output_hidden_states
  1427. )
  1428. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1429. return_dict = (
  1430. return_dict if return_dict is not None else self.config.use_return_dict
  1431. )
  1432. if input_ids is not None and inputs_embeds is not None:
  1433. raise ValueError(
  1434. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  1435. )
  1436. elif input_ids is not None:
  1437. input = input_ids
  1438. input_shape = input.shape
  1439. input_ids = input_ids.reshape([-1, input_shape[-1]])
  1440. elif inputs_embeds is not None:
  1441. input_shape = inputs_embeds.shape[:-1]
  1442. input = inputs_embeds[:, :, -1]
  1443. else:
  1444. raise ValueError(
  1445. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1446. )
  1447. past_key_values_length = (
  1448. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  1449. )
  1450. if inputs_embeds is None:
  1451. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  1452. if self._use_flash_attention_2:
  1453. attention_mask = (
  1454. attention_mask
  1455. if (attention_mask is not None and 0 in attention_mask)
  1456. else None
  1457. )
  1458. else:
  1459. if self.is_export:
  1460. attention_mask = _prepare_4d_causal_attention_mask_export(
  1461. attention_mask,
  1462. input_shape,
  1463. inputs_embeds,
  1464. past_key_values_length,
  1465. is_export=self.is_export,
  1466. ).to(torch.float32)
  1467. else:
  1468. attention_mask = _prepare_4d_causal_attention_mask(
  1469. attention_mask,
  1470. input_shape,
  1471. inputs_embeds,
  1472. past_key_values_length,
  1473. is_export=self.is_export,
  1474. )
  1475. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1476. if self._use_flash_attention_2:
  1477. encoder_attention_mask = (
  1478. encoder_attention_mask if 0 in encoder_attention_mask else None
  1479. )
  1480. else:
  1481. encoder_attention_mask = _prepare_4d_attention_mask(
  1482. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  1483. )
  1484. # embed positions
  1485. positions = self.embed_positions(input, past_key_values_length)
  1486. hidden_states = inputs_embeds + positions
  1487. # TODO: add counting context weight to hidden_states
  1488. if count_pred is not None:
  1489. count_context_weight = self.counting_context_weight(count_pred)
  1490. hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
  1491. hidden_states = self.layernorm_embedding(hidden_states)
  1492. hidden_states = nn.functional.dropout(
  1493. hidden_states, p=self.dropout, training=self.training
  1494. )
  1495. if self.gradient_checkpointing and self.training:
  1496. if use_cache:
print(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
  1500. use_cache = False
  1501. # decoder layers
  1502. all_hidden_states = () if output_hidden_states else None
  1503. all_self_attns = () if output_attentions else None
  1504. all_cross_attentions = (
  1505. () if (output_attentions and encoder_hidden_states is not None) else None
  1506. )
  1507. next_decoder_cache = () if use_cache else None
  1508. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1509. for attn_mask, mask_name in zip(
  1510. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  1511. ):
  1512. if attn_mask is not None:
  1513. if attn_mask.size()[0] != len(self.layers):
  1514. raise ValueError(
  1515. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1516. f" {attn_mask.size()[0]}."
  1517. )
  1518. for idx, decoder_layer in enumerate(self.layers):
  1519. if output_hidden_states:
  1520. all_hidden_states += (hidden_states,)
  1521. if self.training:
dropout_probability = torch.rand([])
  1523. if dropout_probability < self.layerdrop:
  1524. continue
  1525. past_key_value = (
  1526. past_key_values[idx] if past_key_values is not None else None
  1527. )
  1528. if self.gradient_checkpointing and self.training:
  1529. layer_outputs = self._gradient_checkpointing_func(
  1530. decoder_layer.__call__,
  1531. hidden_states,
  1532. attention_mask,
  1533. encoder_hidden_states,
  1534. encoder_attention_mask,
  1535. head_mask[idx] if head_mask is not None else None,
  1536. (
  1537. cross_attn_head_mask[idx]
  1538. if cross_attn_head_mask is not None
  1539. else None
  1540. ),
  1541. None,
  1542. output_attentions,
  1543. use_cache,
  1544. )
  1545. else:
  1546. layer_outputs = decoder_layer(
  1547. hidden_states,
  1548. attention_mask=attention_mask,
  1549. encoder_hidden_states=encoder_hidden_states,
  1550. encoder_attention_mask=encoder_attention_mask,
  1551. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1552. cross_attn_layer_head_mask=(
  1553. cross_attn_head_mask[idx]
  1554. if cross_attn_head_mask is not None
  1555. else None
  1556. ),
  1557. past_key_value=past_key_value,
  1558. output_attentions=output_attentions,
  1559. use_cache=use_cache,
  1560. )
  1561. hidden_states = layer_outputs[0]
  1562. if self.is_export:
  1563. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  1564. else:
  1565. if use_cache:
  1566. next_decoder_cache += (
  1567. layer_outputs[3 if output_attentions else 1],
  1568. )
  1569. if output_attentions:
  1570. all_self_attns += (layer_outputs[1],)
  1571. if encoder_hidden_states is not None:
  1572. all_cross_attentions += (layer_outputs[2],)
  1573. hidden_states = self.layer_norm(hidden_states)
  1574. if output_hidden_states:
  1575. all_hidden_states += (hidden_states,)
  1576. if self.is_export:
  1577. next_cache = next_decoder_cache
  1578. else:
  1579. next_cache = next_decoder_cache if use_cache else None
  1580. if not self.is_export:
  1581. if not return_dict:
  1582. return tuple(
  1583. v
  1584. for v in [
  1585. hidden_states,
  1586. next_cache,
  1587. all_hidden_states,
  1588. all_self_attns,
  1589. all_cross_attentions,
  1590. ]
  1591. if v is not None
  1592. )
  1593. return BaseModelOutputWithPastAndCrossAttentions(
  1594. last_hidden_state=hidden_states,
  1595. past_key_values=next_cache,
  1596. hidden_states=all_hidden_states,
  1597. attentions=all_self_attns,
  1598. cross_attentions=all_cross_attentions,
  1599. )
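# Length-aware addition made by CustomMBartDecoder above: when count_pred (a per-sample
# vocabulary-count prediction of shape (bsz, vocab_size)) is given, it is projected to
# d_model by counting_context_weight and added to every decoder position with a fixed
# 0.5 weight before layernorm_embedding:
#
#   hidden_states = inputs_embeds + positions + 0.5 * counting_context_weight(count_pred).unsqueeze(1)
#
# Apart from this (and the export-specific attention-mask branch), the forward pass
# follows the stock MBartDecoder.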
  1600. class SelfAttentionBlock(nn.Module):
  1601. """
  1602. A self-attention block that implements multi-head self-attention
  1603. followed by a feed-forward network, typically used in transformer architectures.
  1604. Args:
  1605. embed_size (int): The size of the embedding vector.
  1606. num_heads (int): The number of attention heads.
  1607. is_export (bool): Flag indicating whether to configure the layer for export.
  1608. """
  1609. def __init__(self, embed_size, num_heads, is_export):
  1610. super(SelfAttentionBlock, self).__init__()
  1611. self.self_attention = MyMultiheadAttention(
  1612. embed_dim=embed_size, num_heads=num_heads, is_export=is_export
  1613. )
  1614. self.norm = nn.LayerNorm(embed_size)
  1615. def forward(self, x):
  1616. attn_output, _ = self.self_attention(x, x, x)
  1617. x = self.norm(attn_output + x)
  1618. return x
  1619. class SeqCountingDecoder(nn.Module):
  1620. """
  1621. A custom sequence counting decoder that incorporates multi-head attention layers
and feed-forward networks to process sequences, e.g. for predicting LaTeX token counts.
  1623. Args:
  1624. in_features (int): The number of input features.
  1625. out_features (int): The number of output features.
  1626. num_heads (int): The number of attention heads. Defaults to 8.
  1627. num_layers (int): The number of attention layers. Defaults to 4.
  1628. is_export (bool): Flag indicating whether to configure the layer for export.
  1629. """
  1630. def __init__(
  1631. self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
  1632. ):
  1633. super(SeqCountingDecoder, self).__init__()
  1634. self.attention_blocks = nn.ModuleList(
  1635. [
  1636. SelfAttentionBlock(
  1637. embed_size=in_features, num_heads=num_heads, is_export=is_export
  1638. )
  1639. for i in range(num_layers)
  1640. ]
  1641. )
  1642. self.fc1 = nn.Linear(in_features, in_features // 2)
  1643. self.relu = nn.ReLU()
  1644. self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
  1645. self.fc2 = nn.Linear(in_features // 2, out_features)
  1646. def forward(self, x):
  1647. for block in self.attention_blocks:
  1648. x = block(x)
  1649. x = self.fc1(x)
  1650. x = self.relu(x)
x = x.permute(0, 2, 1)
  1652. x = self.global_avg_pool(x)
  1653. x = x.squeeze(-1)
  1654. x = self.fc2(x)
  1655. return x
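# Shape walk-through for SeqCountingDecoder.forward above (in_features/out_features are
# whatever the caller passes; CustomMBartForCausalLM below uses d_model and vocab_size):
#   x: (bsz, seq_len, in_features)            -> unchanged by the attention blocks
#   fc1 + relu: (bsz, seq_len, in_features // 2)
#   permute + AdaptiveAvgPool1d(1) + squeeze:  (bsz, in_features // 2)   # average over seq_len
#   fc2: (bsz, out_features)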
  1656. class CustomMBartForCausalLM(MBartForCausalLM):
  1657. """
  1658. Custom MBart model for causal language modeling with a custom decoder.
  1659. This class extends the MBartForCausalLM by replacing its decoder with a
  1660. custom decoder, allowing for additional flexibility and features in the
  1661. decoding process.
  1662. Args:
  1663. config: The configuration object containing model parameters.
  1664. length_aware (bool): A flag to enable or configure length-aware mechanisms.
  1665. """
  1666. def __init__(self, config, length_aware=True):
  1667. super().__init__(config)
  1668. self.model.decoder = CustomMBartDecoder(config)
  1669. self.counting_decoder = SeqCountingDecoder(
  1670. config.d_model, config.vocab_size, is_export=config.is_export
  1671. )
  1672. self.length_aware = length_aware
  1673. def forward(
  1674. self,
  1675. input_ids=None,
  1676. attention_mask=None,
  1677. encoder_hidden_states=None,
  1678. encoder_attention_mask=None,
  1679. head_mask=None,
  1680. cross_attn_head_mask=None,
  1681. past_key_values=None,
  1682. inputs_embeds=None,
  1683. labels=None,
  1684. use_cache=None,
  1685. output_attentions=None,
  1686. output_hidden_states=None,
  1687. return_dict=None,
  1688. count_gt=None,
  1689. ):
  1690. output_attentions = (
  1691. output_attentions
  1692. if output_attentions is not None
  1693. else self.config.output_attentions
  1694. )
  1695. output_hidden_states = (
  1696. output_hidden_states
  1697. if output_hidden_states is not None
  1698. else self.config.output_hidden_states
  1699. )
  1700. return_dict = (
  1701. return_dict if return_dict is not None else self.config.use_return_dict
  1702. )
  1703. if self.length_aware:
  1704. count_pred = self.counting_decoder(encoder_hidden_states)
  1705. else:
  1706. count_pred = None
  1707. outputs = self.model.decoder(
  1708. input_ids=input_ids,
  1709. attention_mask=attention_mask,
  1710. count_pred=count_pred,
  1711. encoder_hidden_states=encoder_hidden_states,
  1712. encoder_attention_mask=encoder_attention_mask,
  1713. head_mask=head_mask,
  1714. cross_attn_head_mask=cross_attn_head_mask,
  1715. past_key_values=past_key_values,
  1716. inputs_embeds=inputs_embeds,
  1717. use_cache=use_cache,
  1718. output_attentions=output_attentions,
  1719. output_hidden_states=output_hidden_states,
  1720. return_dict=return_dict,
  1721. )
  1722. logits = self.lm_head(outputs[0])
  1723. return CausalLMOutputWithCrossAttentionsAndCounting(
  1724. logits=logits,
  1725. counting=count_pred,
  1726. past_key_values=outputs.past_key_values,
  1727. hidden_states=outputs.hidden_states,
  1728. attentions=outputs.attentions,
  1729. cross_attentions=outputs.cross_attentions,
  1730. )
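# Callers of CustomMBartForCausalLM.forward receive both token logits and the auxiliary
# count prediction, e.g. (sketch; the variable names are illustrative):
#
#   out = decoder_lm(input_ids=ids, encoder_hidden_states=enc_states, return_dict=True)
#   logits = out.logits          # (bsz, tgt_len, vocab_size)
#   count_pred = out.counting    # (bsz, vocab_size), or None when length_aware=False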
  1731. class UniMERNetHead(nn.Module):
  1732. """Implementation of UniMERNetHead decoder.
  1733. Args:
  1734. max_new_tokens (int): Maximum number of new tokens to generate.
  1735. decoder_start_token_id (int): ID of the token that starts the decoding.
  1736. temperature (float): Sampling temperature for generation.
  1737. do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
  1738. top_p (float): Top-p (nucleus) sampling parameter.
  1739. in_channels (int): Number of input channels/features.
  1740. encoder_hidden_size (int): Hidden size of the encoder.
  1741. decoder_hidden_size (int): Hidden size of the decoder.
  1742. decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
  1743. decoder_layers (int): Number of layers in the decoder.
  1744. is_export (bool): Flag indicating if the model is being prepared for export.
  1745. length_aware (bool): Flag to enable length-aware mechanisms.
  1746. """
  1747. def __init__(
  1748. self,
  1749. max_new_tokens=1536,
  1750. decoder_start_token_id=0,
  1751. temperature=0.2,
  1752. do_sample=False,
  1753. top_p=0.95,
  1754. in_channels=1024,
  1755. encoder_hidden_size=1024,
  1756. decoder_hidden_size=1024,
  1757. decoder_ffn_dim=4096,
  1758. decoder_layers=8,
  1759. is_export=False,
  1760. length_aware=True,
  1761. ):
  1762. super().__init__()
  1763. mbart_config_dict = {
  1764. "activation_dropout": 0.0,
  1765. "activation_function": "gelu",
  1766. "add_cross_attention": True,
  1767. "add_final_layer_norm": True,
  1768. "attention_dropout": 0.0,
  1769. "bos_token_id": 0,
  1770. "classifier_dropout": 0.0,
  1771. "d_model": decoder_hidden_size,
  1772. "decoder_attention_heads": 16,
  1773. "decoder_ffn_dim": decoder_ffn_dim,
  1774. "decoder_layerdrop": 0.0,
  1775. "decoder_layers": decoder_layers,
  1776. "dropout": 0.1,
  1777. "encoder_attention_heads": 16,
  1778. "encoder_ffn_dim": 4096,
  1779. "encoder_layerdrop": 0.0,
  1780. "encoder_layers": 12,
  1781. "eos_token_id": 2,
  1782. "forced_eos_token_id": 2,
  1783. "init_std": 0.02,
  1784. "is_decoder": True,
  1785. "is_encoder_decoder": False,
  1786. "output_hidden_states": False,
  1787. "max_position_embeddings": max_new_tokens,
  1788. "model_type": "mbart",
  1789. "num_hidden_layers": 12,
  1790. "pad_token_id": 1,
  1791. "scale_embedding": True,
  1792. "tie_word_embeddings": False,
  1793. "transformers_version": "4.40.0",
  1794. "use_cache": True,
  1795. "use_return_dict": True,
  1796. "vocab_size": 50000,
  1797. "_attn_implementation": "eager",
  1798. "hidden_size": decoder_hidden_size,
  1799. "is_export": is_export,
  1800. }
  1801. self.max_new_tokens = max_new_tokens
  1802. self.decoder_start_token_id = decoder_start_token_id
  1803. self.temperature = temperature
  1804. self.do_sample = do_sample
  1805. self.top_p = top_p
  1806. self.max_seq_len = max_new_tokens
  1807. self.config_decoder = MBartConfig(**mbart_config_dict)
  1808. self.encoder_hidden_size = encoder_hidden_size
  1809. self.is_export = self.config_decoder.is_export
  1810. self.decoder = CustomMBartForCausalLM(
  1811. self.config_decoder, length_aware=length_aware
  1812. )
  1813. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  1814. self.enc_to_dec_proj = nn.Linear(
  1815. self.encoder_hidden_size, self.config_decoder.hidden_size
  1816. )
  1817. generation_config = {
  1818. "max_length": 1537,
  1819. "forced_eos_token_id": 2,
  1820. }
  1821. self.eos_token_id = generation_config["forced_eos_token_id"]
  1822. self.pad_token_id = self.config_decoder.pad_token_id
  1823. self.logits_processor = LogitsProcessorList()
  1824. self.logits_processor.append(
  1825. ForcedEOSTokenLogitsProcessor(
  1826. generation_config["max_length"],
  1827. generation_config["forced_eos_token_id"],
  1828. )
  1829. )
  1830. def _get_decoder_start_token_id(
  1831. self, decoder_start_token_id=None, bos_token_id=None
  1832. ) -> int:
  1833. decoder_start_token_id = (
  1834. decoder_start_token_id
  1835. if decoder_start_token_id is not None
  1836. else self.generation_config.decoder_start_token_id
  1837. )
  1838. bos_token_id = (
  1839. bos_token_id
  1840. if bos_token_id is not None
  1841. else self.generation_config.bos_token_id
  1842. )
  1843. if decoder_start_token_id is not None:
  1844. return decoder_start_token_id
  1845. elif bos_token_id is not None:
  1846. return bos_token_id
  1847. raise ValueError(
  1848. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  1849. )
  1850. def _prepare_decoder_input_ids_for_generation(
  1851. self,
  1852. batch_size,
  1853. model_kwargs,
  1854. decoder_start_token_id=None,
  1855. bos_token_id=None,
  1856. ):
  1857. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  1858. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  1859. elif "input_ids" in model_kwargs:
  1860. decoder_input_ids = model_kwargs.pop("input_ids")
  1861. else:
  1862. decoder_input_ids = None
  1863. decoder_start_token_id = self._get_decoder_start_token_id(
  1864. decoder_start_token_id, bos_token_id
  1865. )
  1866. if isinstance(decoder_start_token_id, list):
  1867. if len(decoder_start_token_id) != batch_size:
  1868. raise ValueError(
  1869. f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
  1870. )
  1871. decoder_input_ids_start = torch.LongTensor(decoder_start_token_id)
  1872. decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
  1873. else:
  1874. decoder_input_ids_start = (
  1875. torch.ones(
  1876. (batch_size, 1),
  1877. dtype=torch.int64,
  1878. )
  1879. * decoder_start_token_id
  1880. )
  1881. if decoder_input_ids is None:
  1882. decoder_input_ids = decoder_input_ids_start
  1883. elif (
  1884. self.config.model_type == "vision-encoder-decoder"
  1885. and "donut" in self.name_or_path.lower()
  1886. ):
  1887. pass
  1888. elif self.config.model_type in ["whisper"]:
  1889. pass
  1890. elif (
  1891. isinstance(decoder_start_token_id, int)
  1892. and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
  1893. ) or (
  1894. isinstance(decoder_start_token_id, torch.Tensor)
  1895. and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
  1896. ):
  1897. decoder_input_ids = torch.concat(
  1898. [decoder_input_ids_start, decoder_input_ids], dim=-1
  1899. )
  1900. if "decoder_attention_mask" in model_kwargs:
  1901. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  1902. decoder_attention_mask = torch.cat(
  1903. (
  1904. torch.ones_like(decoder_attention_mask)[:, :1],
  1905. decoder_attention_mask,
  1906. ),
  1907. dim=-1,
  1908. )
  1909. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  1910. return decoder_input_ids, model_kwargs
  1911. def prepare_inputs_for_generation_mbart(
  1912. self,
  1913. input_ids,
  1914. past_key_values=None,
  1915. attention_mask=None,
  1916. use_cache=None,
  1917. **kwargs,
  1918. ):
  1919. if attention_mask is None:
  1920. attention_mask = torch.ones(input_ids.shape)
  1921. if past_key_values:
  1922. past_length = past_key_values[0][0].shape[2]
  1923. if input_ids.shape[1] > past_length:
  1924. remove_prefix_length = past_length
  1925. else:
  1926. remove_prefix_length = input_ids.shape[1] - 1
  1927. input_ids = input_ids[:, remove_prefix_length:]
  1928. return {
  1929. "input_ids": input_ids,
  1930. "attention_mask": attention_mask,
  1931. "past_key_values": past_key_values,
  1932. "use_cache": use_cache,
  1933. }
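# Cache-aware trimming in prepare_inputs_for_generation_mbart above: once past_key_values
# exist, only the tokens not yet consumed by the decoder are kept (normally just the last
# one), e.g. input_ids of shape (bsz, 5) with past_length 4 is cut to (bsz, 1).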
  1934. def prepare_inputs_for_generation(
  1935. self,
  1936. input_ids,
  1937. past_key_values=None,
  1938. attention_mask=None,
  1939. use_cache=None,
  1940. encoder_outputs=None,
  1941. **kwargs,
  1942. ):
  1943. decoder_inputs = self.prepare_inputs_for_generation_mbart(
  1944. input_ids, past_key_values=past_key_values
  1945. )
  1946. decoder_attention_mask = (
  1947. decoder_inputs["attention_mask"]
  1948. if "attention_mask" in decoder_inputs
  1949. else None
  1950. )
  1951. input_dict = {
  1952. "attention_mask": attention_mask,
  1953. "decoder_attention_mask": decoder_attention_mask,
  1954. "decoder_input_ids": decoder_inputs["input_ids"],
  1955. "encoder_outputs": encoder_outputs,
  1956. "past_key_values": decoder_inputs["past_key_values"],
  1957. "use_cache": use_cache,
  1958. }
  1959. return input_dict
  1960. def prepare_inputs_for_generation_export(
  1961. self,
  1962. past_key_values=None,
  1963. attention_mask=None,
  1964. use_cache=None,
  1965. encoder_outputs=None,
  1966. **kwargs,
  1967. ):
  1968. input_dict = {
  1969. "decoder_attention_mask": None,
  1970. "use_cache": use_cache,
  1971. }
  1972. return input_dict
  1973. def _extract_past_from_model_output(
  1974. self, outputs: ModelOutput, standardize_cache_format: bool = False
  1975. ):
  1976. past_key_values = None
  1977. if "past_key_values" in outputs:
  1978. past_key_values = outputs.past_key_values
  1979. elif "mems" in outputs:
  1980. past_key_values = outputs.mems
  1981. elif "past_buckets_states" in outputs:
  1982. past_key_values = outputs.past_buckets_states
  1983. return past_key_values
  1984. def _update_model_kwargs_for_generation(
  1985. self,
  1986. outputs: ModelOutput,
  1987. model_kwargs: Dict[str, Any],
  1988. is_encoder_decoder: bool = False,
  1989. standardize_cache_format: bool = False,
  1990. ) -> Dict[str, Any]:
  1991. model_kwargs["past_key_values"] = self._extract_past_from_model_output(
  1992. outputs, standardize_cache_format=standardize_cache_format
  1993. )
  1994. if getattr(outputs, "state", None) is not None:
  1995. model_kwargs["state"] = outputs.state
  1996. if "token_type_ids" in model_kwargs:
  1997. token_type_ids = model_kwargs["token_type_ids"]
  1998. model_kwargs["token_type_ids"] = torch.concat(
  1999. [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
  2000. )
  2001. if not is_encoder_decoder:
  2002. if "attention_mask" in model_kwargs:
  2003. attention_mask = model_kwargs["attention_mask"]
  2004. model_kwargs["attention_mask"] = torch.concat(
  2005. [
  2006. attention_mask,
  2007. attention_mask.new_ones((attention_mask.shape[0], 1)),
  2008. ],
  2009. dim=-1,
  2010. )
  2011. else:
  2012. if "decoder_attention_mask" in model_kwargs:
  2013. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  2014. model_kwargs["decoder_attention_mask"] = torch.concat(
  2015. [
  2016. decoder_attention_mask,
  2017. decoder_attention_mask.new_ones(
  2018. (decoder_attention_mask.shape[0], 1)
  2019. ),
  2020. ],
  2021. dim=-1,
  2022. )
  2023. if (
  2024. "cache_position" in model_kwargs
  2025. and model_kwargs["cache_position"] is not None
  2026. ):
  2027. model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
  2028. return model_kwargs
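# After every decoding step the helper above re-attaches the cache returned by the model
# and grows the relevant attention mask by one column of ones, e.g. a (bsz, t)
# decoder_attention_mask becomes (bsz, t + 1), matching the token appended to input_ids.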
def stopping_criteria(self, input_ids):
eos = torch.tensor([self.eos_token_id], device=input_ids.device)
if self.is_export:
return input_ids[:, -1] == eos
is_done = torch.isin(input_ids[:, -1], eos)
return is_done
  2034. def generate_single_iter(
  2035. self,
  2036. decoder_input_ids=None,
  2037. decoder_attention_mask=None,
  2038. encoder_outputs=None,
  2039. past_key_values=None,
  2040. decoder_inputs_embeds=None,
  2041. labels=None,
  2042. use_cache=None,
  2043. output_attentions=None,
  2044. output_hidden_states=None,
  2045. return_dict=None,
  2046. **kwargs,
  2047. ):
  2048. encoder_hidden_states = encoder_outputs[0]
  2049. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  2050. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  2051. kwargs_decoder = {}
  2052. decoder_outputs = self.decoder(
  2053. input_ids=decoder_input_ids,
  2054. attention_mask=decoder_attention_mask,
  2055. encoder_hidden_states=encoder_hidden_states,
  2056. encoder_attention_mask=None,
  2057. inputs_embeds=None,
  2058. output_attentions=False,
  2059. output_hidden_states=output_hidden_states,
  2060. use_cache=use_cache,
  2061. past_key_values=past_key_values,
  2062. return_dict=return_dict,
  2063. **kwargs_decoder,
  2064. )
  2065. return Seq2SeqLMOutput(
  2066. loss=None,
  2067. logits=decoder_outputs.logits,
  2068. past_key_values=decoder_outputs.past_key_values,
  2069. decoder_hidden_states=decoder_outputs.hidden_states,
  2070. decoder_attentions=decoder_outputs.attentions,
  2071. cross_attentions=decoder_outputs.cross_attentions,
  2072. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  2073. encoder_hidden_states=encoder_outputs.hidden_states,
  2074. encoder_attentions=encoder_outputs.attentions,
  2075. )
  2076. @torch.no_grad()
  2077. def generate(
  2078. self,
  2079. model_kwargs,
  2080. ):
  2081. """
  2082. Generate sequences using the UniMERNetHead for inference tasks.
  2083. Args:
  2084. model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
  2085. - encoder_outputs: Outputs from the encoder.
  2086. - use_cache: Boolean flag to indicate if caching should be used.
  2087. - output_attentions: Boolean flag for outputting attention scores.
  2088. - output_hidden_states: Boolean flag for outputting hidden states.
  2089. Returns:
  2090. A tensor containing the generated sequences.
  2091. """
  2092. batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
  2093. generation_config = {
  2094. "decoder_start_token_id": 0,
  2095. "bos_token_id": 0,
  2096. }
  2097. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2098. batch_size=batch_size,
  2099. model_kwargs=model_kwargs,
  2100. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2101. bos_token_id=generation_config["bos_token_id"],
  2102. )
  2103. model_kwargs["key use_cache"] = True
  2104. batch_size, cur_len = input_ids.shape
  2105. if "inputs_embeds" in model_kwargs:
  2106. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2107. model_kwargs["cache_position"] = torch.arange(cur_len)
  2108. pad_token_id = self.pad_token_id
  2109. eos_token_id = [self.eos_token_id]
  2110. eos_token = self.eos_token_id
  2111. unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
  2112. for idx in range(self.max_seq_len):
  2113. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  2114. outputs = self.generate_single_iter(
  2115. **model_inputs,
  2116. return_dict=True,
  2117. output_attentions=False,
  2118. output_hidden_states=False,
  2119. )
  2120. next_token_logits = outputs.logits[:, -1, :]
  2121. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2122. next_tokens = torch.argmax(next_tokens_scores, dim=-1)
  2123. if eos_token_id is not None:
  2124. if pad_token_id is None:
  2125. raise ValueError(
  2126. "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
  2127. )
  2128. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2129. 1 - unfinished_sequences
  2130. )
  2131. input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
  2132. model_kwargs = self._update_model_kwargs_for_generation(
  2133. outputs,
  2134. model_kwargs,
  2135. is_encoder_decoder=self.config_decoder.is_encoder_decoder,
  2136. )
  2137. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2138. input_ids
  2139. ).to(torch.int64)
  2140. if (
  2141. eos_token is not None
  2142. and (
  2143. torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
  2144. >= 1
  2145. ).all()
  2146. ):
  2147. break
  2148. return input_ids
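# Greedy-decoding sketch for the loop above (names are illustrative; at inference time
# generate() is normally reached through UniMERNetHead.forward):
#
#   model_kwargs = {"encoder_outputs": encoder_outputs, "use_cache": True,
#                   "output_attentions": False, "output_hidden_states": False}
#   token_ids = head.generate(model_kwargs)   # LongTensor of generated token ids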
  2149. @torch.no_grad()
  2150. def generate_export(
  2151. self,
  2152. encoder_outputs,
  2153. model_kwargs,
  2154. ):
  2155. batch_size = encoder_outputs["last_hidden_state"].shape[0]
  2156. generation_config = {
  2157. "decoder_start_token_id": 0,
  2158. "bos_token_id": 0,
  2159. }
  2160. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2161. batch_size=batch_size,
  2162. model_kwargs=model_kwargs,
  2163. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2164. bos_token_id=generation_config["bos_token_id"],
  2165. )
  2166. input_ids = input_ids.reshape([-1, 1])
  2167. decoder_input_ids = input_ids
  2168. model_kwargs["key use_cache"] = True
  2169. batch_size, cur_len = input_ids.shape
  2170. if "inputs_embeds" in model_kwargs:
  2171. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2172. cache_position = torch.arange(cur_len)
  2173. pad_token_id = self.pad_token_id
  2174. eos_token_id = [self.eos_token_id]
  2175. eos_token = self.eos_token_id
  2176. unfinished_sequences = torch.ones([batch_size], dtype=torch.int64)
  2177. i_idx = torch.full([], 0)
  2178. past_key_values = []
  2179. for i in range(8):
  2180. init_arr = torch.zeros([batch_size, 16, 0, 64])
  2181. cache = (init_arr, init_arr, init_arr, init_arr)
  2182. past_key_values.append(cache)
  2183. idx = 0
while i_idx < torch.tensor(self.max_seq_len):
  2185. model_inputs = self.prepare_inputs_for_generation_export(
  2186. past_key_values=past_key_values, **model_kwargs
  2187. )
# model_inputs["decoder_attention_mask"] is always None in the export path; use a full mask of ones.
decoder_attention_mask = torch.ones(input_ids.shape)
  2190. outputs = self.generate_single_iter(
  2191. decoder_input_ids=decoder_input_ids,
  2192. decoder_attention_mask=decoder_attention_mask,
  2193. encoder_outputs=encoder_outputs,
  2194. past_key_values=past_key_values,
  2195. return_dict=True,
  2196. output_attentions=False,
  2197. output_hidden_states=False,
  2198. )
  2199. next_token_logits = outputs.logits[:, -1, :]
  2200. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2201. next_tokens = torch.argmax(next_tokens_scores, dim=-1)
  2202. if eos_token_id is not None:
  2203. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2204. 1 - unfinished_sequences
  2205. )
  2206. input_ids = torch.concat([input_ids, next_tokens.unsqueeze(1)], dim=-1)
  2207. past_length = past_key_values[0][0].shape[2]
  2208. decoder_input_ids = next_tokens.unsqueeze(1)
  2209. past_key_values = outputs.past_key_values
  2210. cache_position = cache_position[-1:] + 1
  2211. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2212. input_ids
  2213. ).to(torch.int64)
  2214. if (
  2215. eos_token is not None
  2216. and (
  2217. torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
  2218. >= 1
  2219. ).all()
  2220. ):
  2221. break
  2222. i_idx += 1
  2223. return input_ids
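# generate_export mirrors generate() but is shaped for static-graph export: the KV cache
# is pre-seeded with zero-length (bsz, 16, 0, 64) tensors per layer (16 heads x 64 head_dim
# for the default 1024-d decoder), the loop counter is a 0-dim tensor, and only the last
# predicted token is fed back as decoder_input_ids at each step.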
def forward_train(
  2225. self,
  2226. encoder_outputs,
  2227. decoder_input_ids,
  2228. decoder_attention_mask,
  2229. past_key_values=None,
  2230. decoder_inputs_embeds=None,
  2231. labels=None,
  2232. use_cache=None,
  2233. output_attentions=None,
  2234. output_hidden_states=None,
  2235. return_dict=None,
  2236. **kwargs,
  2237. ):
  2238. """
  2239. Training for the UniMERNetHead.
  2240. Args:
  2241. encoder_outputs: Outputs from the encoder, used as input to the decoder.
  2242. decoder_input_ids: Input IDs for the decoder.
  2243. decoder_attention_mask: Attention mask for the decoder inputs.
  2244. past_key_values: Cached key/values for faster decoding.
  2245. decoder_inputs_embeds: Optional embeddings for the decoder inputs.
  2246. labels: Target labels for calculating loss.
  2247. use_cache: Whether to use cache during decoding.
  2248. output_attentions: Whether to return attention scores.
  2249. output_hidden_states: Whether to return hidden states.
  2250. return_dict: Whether to return a dictionary of outputs.
  2251. **kwargs: Additional keyword arguments.
  2252. Returns:
  2253. logits: The raw, unnormalized predictions from the model.
  2254. count_pred: Optional prediction related to sequence length or other counts.
  2255. masked_labels: The labels used during training, possibly masked.
  2256. """
labels = decoder_input_ids.clone()
labels = labels.masked_fill(labels == self.pad_token_id, -100)
  2259. input_decoder_input_ids = decoder_input_ids[:, :-1]
  2260. input_decoder_attention_mask = decoder_attention_mask[:, :-1]
  2261. encoder_hidden_states = encoder_outputs[0]
  2262. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  2263. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  2264. kwargs_decoder = {}
  2265. decoder_outputs = self.decoder(
  2266. input_ids=input_decoder_input_ids,
  2267. attention_mask=input_decoder_attention_mask,
  2268. encoder_hidden_states=encoder_hidden_states,
  2269. encoder_attention_mask=None,
  2270. inputs_embeds=None,
  2271. output_attentions=False,
  2272. output_hidden_states=output_hidden_states,
  2273. use_cache=use_cache,
  2274. past_key_values=past_key_values,
  2275. return_dict=return_dict,
  2276. **kwargs_decoder,
  2277. )
  2278. logits = decoder_outputs.logits
  2279. count_pred = decoder_outputs.counting
  2280. return logits, count_pred, labels
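# Teacher forcing in forward_train above: the decoder consumes decoder_input_ids[:, :-1]
# while labels keep the full sequence (pad positions masked to -100), so the loss computed
# downstream can align the prediction at step t with the ground-truth token at step t + 1.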
  2281. def forward(self, inputs, targets=None):
  2282. """
  2283. Forward pass for the UniMERNetHead, handling both training and inference.
  2284. Args:
  2285. inputs: The input data, which can vary based on training or inference.
  2286. targets: The target labels, used only during training.
  2287. Returns:
  2288. During inference: Returns predicted latex code.
  2289. During training: Returns logits, predicted counts, and masked labels.
  2290. """
self.is_export = not self.training
  2292. if not self.training:
  2293. encoder_outputs = inputs
  2294. if self.is_export:
  2295. model_kwargs = {
  2296. "output_attentions": False,
  2297. "output_hidden_states": False,
  2298. "use_cache": True,
  2299. }
  2300. word_pred = self.generate_export(encoder_outputs, model_kwargs)
  2301. else:
  2302. model_kwargs = {
  2303. "output_attentions": False,
  2304. "output_hidden_states": False,
  2305. "use_cache": True,
  2306. "encoder_outputs": encoder_outputs,
  2307. }
  2308. word_pred = self.generate(model_kwargs)
  2309. return word_pred
  2310. encoder_outputs, tgt_seq, mask = inputs
logits, count_pred, masked_labels = self.forward_train(
encoder_outputs, tgt_seq, mask
)
  2314. return logits, count_pred, masked_labels
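# Minimal inference sketch for UniMERNetHead (illustrative; `encoder_outputs` is whatever
# the surrounding vision encoder produces, with last_hidden_state of shape
# (bsz, seq_len, encoder_hidden_size)):
#
#   head = UniMERNetHead(is_export=False)
#   head.eval()
#   with torch.no_grad():
#       token_ids = head(encoder_outputs)   # greedy-decoded LaTeX token ids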