utils.py

  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import inspect
  16. from typing import List, Optional, Union
  17. import paddle
  18. import paddle.distributed as dist
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle import Tensor
  22. from paddle.common_ops_import import convert_dtype
  23. from paddle.utils import map_structure
  24. from ......utils import logging
  25. from ..transformers.model_outputs import ModelOutput
  26. from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
  27. from .logits_process import (
  28. ForcedBOSTokenLogitsProcessor,
  29. ForcedEOSTokenLogitsProcessor,
  30. HammingDiversityLogitsProcessor,
  31. LogitsProcessor,
  32. LogitsProcessorList,
  33. MinLengthLogitsProcessor,
  34. NoRepeatNGramLogitsProcessor,
  35. RepetitionPenaltyLogitsProcessor,
  36. TopKProcess,
  37. TopPProcess,
  38. )
  39. from .stopping_criteria import (
  40. StoppingCriteria,
  41. StoppingCriteriaList,
  42. validate_stopping_criteria,
  43. )
  44. __all__ = [
  45. "GenerationMixin",
  46. "BeamSearchScorer",
  47. "BeamHypotheses",
  48. "LogitsProcessorList",
  49. "LogitsProcessor",
  50. "MinLengthLogitsProcessor",
  51. "RepetitionPenaltyLogitsProcessor",
  52. "TopKProcess",
  53. "TopPProcess",
  54. "get_unfinished_flag",
  55. ]
  56. def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
  57. """get scale value by dtype
  58. Args:
  59. dtype (str): the string dtype value; if None, the paddle default dtype is used
  60. Returns:
  61. float: the scale value
  62. """
  63. if dtype is None:
  64. dtype = paddle.get_default_dtype()
  65. dtype = convert_dtype(dtype)
  66. scale_value = 1e6
  67. # TODO(wj-Mcaf): support int8, int4 dtypes later
  68. if dtype == "float16":
  69. scale_value = 1e4
  70. if return_positive:
  71. return scale_value
  72. return -1 * scale_value
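# A minimal usage sketch (illustrative values, not part of the original file): the
# returned scale acts as a large finite stand-in for +/- infinity when building
# additive attention masks, and drops to 1e4 for float16 to avoid overflow.
#
#   >>> get_scale_by_dtype("float32")                         # 1000000.0
#   >>> get_scale_by_dtype("float16")                         # 10000.0
#   >>> get_scale_by_dtype("float16", return_positive=False)  # -10000.0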
  73. def get_unfinished_flag(
  74. input_ids: Tensor,
  75. unfinished_flag: Tensor,
  76. eos_token_id: Union[int, List[int], List[List[int]]],
  77. ) -> Tensor:
  78. """get unfinished flag for generation step
  79. Args:
  80. input_ids (Tensor): the input_ids
  81. eos_token_id (Union[int, list[int], list[list[int]]]): the end-of-sentence flag, which can be:
  82. * single token id, eg: 10
  83. * multiple token ids to stop generation, eg: [10, 10]
  84. * multiple token id sequences to stop generation, eg: [[10], [20, 20], [30, 30, 30]]
  85. Returns:
  86. Tensor: the unfinished flag tensor
  87. """
  88. if isinstance(eos_token_id, int):
  89. unfinished_flag = paddle.logical_and(
  90. unfinished_flag, input_ids[:, -1:] != eos_token_id
  91. )
  92. else:
  93. batch_unfinish_flag = None
  94. for batch_eos_token_id in eos_token_id:
  95. if batch_unfinish_flag is None:
  96. batch_unfinish_flag = ~get_unfinished_flag(
  97. input_ids, unfinished_flag, batch_eos_token_id
  98. )
  99. else:
  100. batch_unfinish_flag = paddle.logical_or(
  101. batch_unfinish_flag,
  102. ~get_unfinished_flag(
  103. input_ids, unfinished_flag, batch_eos_token_id
  104. ),
  105. )
  106. unfinished_flag = ~batch_unfinish_flag
  107. return unfinished_flag
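# A minimal usage sketch (hypothetical tensors): with a scalar eos_token_id, only
# the last generated token is compared, so a sequence whose newest token equals the
# eos id is marked finished while the rest stay unfinished.
#
#   >>> ids = paddle.to_tensor([[1, 7], [1, 9]])
#   >>> flag = paddle.full([2, 1], True, dtype="bool")
#   >>> get_unfinished_flag(ids, flag, 9)   # -> [[True], [False]]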
  108. class BeamHypotheses:
  109. def __init__(self, num_beams, length_penalty, early_stopping):
  110. """
  111. Initialize n-best list of hypotheses.
  112. """
  113. self.length_penalty = length_penalty
  114. self.early_stopping = early_stopping
  115. self.num_beams = num_beams
  116. self.beams = []
  117. self.worst_score = get_scale_by_dtype()
  118. def __len__(self):
  119. """
  120. Number of hypotheses in the list.
  121. """
  122. return len(self.beams)
  123. def add(self, hyp, sum_logprobs, origin_len=0):
  124. """
  125. Add a new hypothesis to the list.
  126. """
  127. score = sum_logprobs / (
  128. ((hyp.shape[-1] - origin_len + 5) / 6) ** self.length_penalty
  129. )
  130. if len(self) < self.num_beams or score > self.worst_score:
  131. self.beams.append((score, hyp))
  132. if len(self) > self.num_beams:
  133. sorted_next_scores = sorted(
  134. [(s, idx) for idx, (s, _) in enumerate(self.beams)]
  135. )
  136. del self.beams[sorted_next_scores[0][1]]
  137. self.worst_score = sorted_next_scores[1][0]
  138. else:
  139. self.worst_score = min(score, self.worst_score)
  140. def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
  141. """
  142. If there are enough hypotheses and none of the hypotheses being
  143. generated can become better than the worst one in the heap, then we
  144. are done with this sentence.
  145. """
  146. if len(self) < self.num_beams:
  147. return False
  148. elif self.early_stopping:
  149. return True
  150. else:
  151. cur_score = (
  152. best_sum_logprobs
  153. / ((cur_len - origin_len + 5) / 6) ** self.length_penalty
  154. )
  155. ret = self.worst_score >= cur_score
  156. return ret
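# Worked example (hypothetical numbers): `add` normalizes the summed log-probability
# with the GNMT-style length penalty ((len - origin_len + 5) / 6) ** length_penalty.
# A 7-token hypothesis with sum_logprobs=-4.0, origin_len=0 and length_penalty=1.0
# therefore scores -4.0 / ((7 + 5) / 6) = -2.0.
#
#   >>> hyps = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)
#   >>> hyps.add(paddle.to_tensor([2, 5, 7, 7, 3, 9, 1]), sum_logprobs=-4.0)
#   >>> len(hyps), hyps.worst_score   # -> (1, -2.0)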
  157. class BeamSearchScorer(object):
  158. """
  159. Implements standard beam search decoding.
  160. """
  161. def __init__(
  162. self,
  163. batch_size,
  164. max_length,
  165. num_beams,
  166. length_penalty=1.0,
  167. do_early_stopping=False,
  168. num_beam_hyps_to_keep=1,
  169. num_beam_groups=1,
  170. ):
  171. self.max_length = max_length
  172. self.num_beams = num_beams
  173. self.length_penalty = length_penalty
  174. self.do_early_stopping = do_early_stopping
  175. self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
  176. self.num_beam_groups = num_beam_groups
  177. self.group_size = self.num_beams // self.num_beam_groups
  178. self._is_init = False
  179. self._beam_hyps = [
  180. BeamHypotheses(
  181. num_beams=self.num_beams,
  182. length_penalty=self.length_penalty,
  183. early_stopping=self.do_early_stopping,
  184. )
  185. for _ in range(batch_size)
  186. ]
  187. self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")
  188. if not isinstance(num_beams, int) or num_beams <= 1:
  189. raise ValueError(
  190. "`num_beams` has to be an integer strictly greater than 1, but "
  191. "received {}. For `num_beams` == 1, one should make use of "
  192. "`greedy_search` instead.".format(num_beams)
  193. )
  194. if (
  195. not isinstance(num_beam_groups, int)
  196. or (num_beam_groups > num_beams)
  197. or (num_beams % num_beam_groups != 0)
  198. ):
  199. raise ValueError(
  200. "`num_beam_groups` has to be an integer smaller or equal than "
  201. "`num_beams` and `num_beams` has to be divisible by "
  202. "`num_beam_groups`, but received num_beam_groups={}, num_beams="
  203. "{}.".format(num_beam_groups, num_beams)
  204. )
  205. @property
  206. def is_done(self):
  207. return paddle.min(self._done) == 1
  208. def process(
  209. self,
  210. input_ids,
  211. next_scores,
  212. next_tokens,
  213. next_indices,
  214. origin_len=0,
  215. pad_token_id=None,
  216. eos_token_id=None,
  217. ):
  218. cur_len = input_ids.shape[-1]
  219. batch_size = len(self._beam_hyps)
  220. assert batch_size == (input_ids.shape[0] // self.group_size)
  221. next_beam_scores = paddle.zeros(
  222. [batch_size, self.group_size], dtype=next_scores.dtype
  223. )
  224. next_beam_tokens = paddle.zeros(
  225. [batch_size, self.group_size], dtype=next_tokens.dtype
  226. )
  227. next_beam_indices = paddle.zeros(
  228. [batch_size, self.group_size], dtype=next_indices.dtype
  229. )
  230. for batch_idx, beam_hyp in enumerate(self._beam_hyps):
  231. if self._done[batch_idx] == 1:
  232. assert (
  233. len(beam_hyp) >= self.num_beams
  234. ), "Batch can only be done if at least {} beams have been generated".format(
  235. self.num_beams
  236. )
  237. assert (
  238. eos_token_id is not None and pad_token_id is not None
  239. ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
  240. # pad the batch
  241. next_beam_scores[batch_idx, :] = 0
  242. next_beam_tokens[batch_idx, :] = pad_token_id
  243. next_beam_indices[batch_idx, :] = 0
  244. continue
  245. # next tokens for this sentence
  246. beam_idx = 0
  247. for beam_token_rank, (next_token, next_score, next_index) in enumerate(
  248. zip(
  249. next_tokens[batch_idx],
  250. next_scores[batch_idx],
  251. next_indices[batch_idx],
  252. )
  253. ):
  254. batch_beam_idx = batch_idx * self.group_size + next_index
  255. # add to generated hypotheses if end of sentence
  256. if (eos_token_id is not None) and (next_token.item() == eos_token_id):
  257. # If beam_token does not belong to top num_beams tokens,
  258. # it should not be added
  259. is_beam_token_worse_than_top_num_beams = (
  260. beam_token_rank >= self.group_size
  261. )
  262. if is_beam_token_worse_than_top_num_beams:
  263. continue
  264. beam_hyp.add(
  265. input_ids[batch_beam_idx.item()].clone(),
  266. next_score.item(),
  267. origin_len,
  268. )
  269. else:
  270. # add next predicted token since it is not eos_token
  271. next_beam_scores[batch_idx, beam_idx] = next_score
  272. next_beam_tokens[batch_idx, beam_idx] = next_token.item()
  273. next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
  274. beam_idx += 1
  275. # once the beam for next step is full, don't add more tokens to it.
  276. if beam_idx == self.group_size:
  277. break
  278. if beam_idx < self.group_size:
  279. raise ValueError(
  280. "At most {} tokens in `next_tokens[batch_idx]` can be equal "
  281. "to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
  282. "are corrected.".format(self.group_size, eos_token_id)
  283. )
  284. # Check if we are done so that we can save a pad step if all(done)
  285. if beam_hyp.is_done(
  286. next_scores[batch_idx].max().item(), cur_len, origin_len
  287. ):
  288. self._done[batch_idx] = 1
  289. return {
  290. "next_beam_scores": next_beam_scores.reshape([-1]),
  291. "next_beam_tokens": next_beam_tokens.reshape([-1]),
  292. "next_beam_indices": next_beam_indices.reshape([-1]),
  293. }
  294. def finalize(
  295. self,
  296. input_ids,
  297. final_beam_scores,
  298. final_beam_tokens,
  299. final_beam_indices,
  300. origin_len=0,
  301. pad_token_id=None,
  302. eos_token_id=None,
  303. ):
  304. batch_size = len(self._beam_hyps)
  305. # finalize all open beam hypotheses and add to generated hypotheses
  306. for batch_idx, beam_hyp in enumerate(self._beam_hyps):
  307. if self._done[batch_idx] == 1:
  308. continue
  309. # all open beam hypotheses are added to the beam hypothesis
  310. # beam hypothesis class automatically keeps the best beams
  311. for beam_id in range(self.num_beams):
  312. batch_beam_idx = batch_idx * self.num_beams + beam_id
  313. final_score = final_beam_scores[batch_beam_idx].item()
  314. final_tokens = input_ids[batch_beam_idx]
  315. beam_hyp.add(final_tokens, final_score, origin_len=origin_len)
  316. # select the best hypotheses
  317. sent_lengths = paddle.zeros(
  318. [batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
  319. )
  320. best = []
  321. # retrieve best hypotheses
  322. for i, beam_hyp in enumerate(self._beam_hyps):
  323. sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
  324. for j in range(self.num_beam_hyps_to_keep):
  325. best_score, best_hyp = sorted_hyps.pop()
  326. sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
  327. best.append([best_hyp, best_score])
  328. # prepare for adding eos
  329. sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
  330. decoded = paddle.zeros(
  331. [batch_size * self.num_beam_hyps_to_keep, sent_max_len],
  332. dtype=input_ids.dtype,
  333. )
  334. # shorter batches are padded if needed
  335. if sent_lengths.min().item() != sent_lengths.max().item():
  336. assert pad_token_id is not None, "`pad_token_id` has to be defined"
  337. decoded[:, :] = pad_token_id
  338. decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])
  339. # fill with hypotheses and eos_token_id if the latter fits in
  340. for i, (hypo, score) in enumerate(best):
  341. decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
  342. decoded_score[i] = score
  343. if sent_lengths[i] < self.max_length:
  344. decoded[i, sent_lengths[i].item()] = eos_token_id
  345. return decoded, decoded_score
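# Usage sketch (assumed shapes, hedged; the driving beam_search loop is defined
# elsewhere in this class): `process` consumes the top candidates for each batch
# item (more than group_size of them, so finished beams can be skipped) and returns
# flattened scores/tokens/indices for the next step; `finalize` then pads and
# returns the best `num_beam_hyps_to_keep` hypotheses per batch item.
#
#   >>> scorer = BeamSearchScorer(batch_size=1, max_length=20, num_beams=4)
#   >>> # per step, with input_ids of shape [1 * 4, cur_len] and candidate tensors
#   >>> # of shape [1, 2 * 4]:
#   >>> # out = scorer.process(input_ids, next_scores, next_tokens, next_indices,
#   >>> #                      pad_token_id=0, eos_token_id=2)
#   >>> # decoded, decoded_score = scorer.finalize(input_ids, final_beam_scores,
#   >>> #                                          final_beam_tokens, final_beam_indices,
#   >>> #                                          pad_token_id=0, eos_token_id=2)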
  346. class GenerationMixin(object):
  347. r"""
  348. This class implements the interface for generation tasks.
  349. It's used as the base class of `paddlenlp.transformers.PretrainedModel
  350. <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
  351. """
  352. # enable `to_static` method for CausalLM Model
  353. enable_to_static_method = False
  354. @staticmethod
  355. def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
  356. batch_size = 1
  357. if bos_token_id is None:
  358. raise ValueError(
  359. "`bos_token_id` should be defined when no " "`input_ids` are provided."
  360. )
  361. if encoder_output is not None:
  362. batch_size = encoder_output.shape[0]
  363. return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
  364. @staticmethod
  365. def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
  366. is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
  367. input_ids == pad_token_id
  368. ).item()
  369. is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
  370. (eos_token_id is not None) and (pad_token_id != eos_token_id)
  371. )
  372. if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
  373. attention_mask = (input_ids == pad_token_id).astype(
  374. paddle.get_default_dtype()
  375. ) * get_scale_by_dtype(return_positive=False)
  376. else:
  377. attention_mask = paddle.zeros_like(
  378. input_ids, dtype=paddle.get_default_dtype()
  379. )
  380. return paddle.unsqueeze(attention_mask, axis=[1, 2])
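# Worked example (hypothetical ids, pad_token_id=0, default float32 dtype): pad
# positions receive a large negative additive bias and real tokens receive 0, then
# the mask is unsqueezed to [batch_size, 1, 1, seq_len] for broadcasting over heads.
#
#   >>> ids = paddle.to_tensor([[0, 5, 6], [3, 4, 7]])
#   >>> mask = GenerationMixin.prepare_attention_mask_for_generation(ids, 0, 2)
#   >>> # mask.shape -> [2, 1, 1, 3]; mask[0, 0, 0] -> [-1e6, 0., 0.]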
  381. @staticmethod
  382. def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
  383. is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
  384. input_ids == pad_token_id
  385. ).item()
  386. is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
  387. (eos_token_id is not None) and (pad_token_id != eos_token_id)
  388. )
  389. if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
  390. seq_len = paddle.sum(input_ids != pad_token_id, axis=1).unsqueeze(-1)
  391. else:
  392. seq_len = paddle.full(
  393. (input_ids.shape[0], 1), input_ids.shape[1], dtype="int64"
  394. )
  395. return seq_len
  396. def get_logits_processor(
  397. self,
  398. min_length=None,
  399. max_length=None,
  400. eos_token_id=None,
  401. forced_bos_token_id=None,
  402. forced_eos_token_id=None,
  403. num_beams=1,
  404. num_beam_groups=1,
  405. diversity_rate=0.0,
  406. repetition_penalty=None,
  407. no_repeat_ngram_size=None,
  408. logits_processors=None,
  409. ):
  410. processors = LogitsProcessorList()
  411. if min_length is not None and eos_token_id is not None and min_length > -1:
  412. processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
  413. if num_beam_groups > 1 and diversity_rate > 0.0:
  414. processors.append(
  415. HammingDiversityLogitsProcessor(
  416. diversity_rate=diversity_rate,
  417. num_beams=num_beams,
  418. num_beam_groups=num_beam_groups,
  419. )
  420. )
  421. if repetition_penalty is not None and repetition_penalty != 1.0:
  422. processors.append(
  423. RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
  424. )
  425. if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
  426. processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
  427. if forced_bos_token_id is not None:
  428. processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
  429. if forced_eos_token_id is not None:
  430. processors.append(
  431. ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
  432. )
  433. # TODO
  434. # Add more pre_processing for distribution
  435. if logits_processors is not None:
  436. custom_processors = LogitsProcessorList()
  437. custom_processors_type = [type(lp) for lp in logits_processors]
  438. for processor in processors:
  439. if type(processor) not in custom_processors_type:
  440. custom_processors.append(processor)
  441. custom_processors.extend(logits_processors)
  442. return custom_processors
  443. else:
  444. return processors
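# Illustrative sketch (hypothetical `model` instance and values): built-in
# processors are assembled from the generation arguments; a user-supplied
# LogitsProcessorList is merged in and replaces built-ins of the same type.
#
#   >>> custom = LogitsProcessorList()
#   >>> custom.append(RepetitionPenaltyLogitsProcessor(penalty=1.3))
#   >>> procs = model.get_logits_processor(
#   ...     min_length=10, max_length=64, eos_token_id=2,
#   ...     repetition_penalty=1.2, logits_processors=custom)
#   >>> # keeps MinLengthLogitsProcessor(10, 2), but uses the custom
#   >>> # RepetitionPenaltyLogitsProcessor(penalty=1.3) instead of the default 1.2 one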
  445. @staticmethod
  446. def expand_inputs_for_generation(
  447. input_ids, expand_size, attention_mask=None, **model_kwargs
  448. ):
  449. index = paddle.tile(
  450. paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1),
  451. [1, expand_size],
  452. ).reshape([-1])
  453. input_ids = paddle.gather(input_ids, index)
  454. if attention_mask is not None:
  455. model_kwargs["attention_mask"] = paddle.gather(attention_mask, index)
  456. if (
  457. "token_type_ids" in model_kwargs
  458. and model_kwargs["token_type_ids"] is not None
  459. ):
  460. token_type_ids = model_kwargs["token_type_ids"]
  461. model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index)
  462. if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
  463. position_ids = model_kwargs["position_ids"]
  464. model_kwargs["position_ids"] = paddle.gather(position_ids, index)
  465. if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
  466. seq_len = model_kwargs["seq_len"]
  467. model_kwargs["seq_len"] = paddle.gather(seq_len, index)
  468. if (
  469. "encoder_output" in model_kwargs
  470. and model_kwargs["encoder_output"] is not None
  471. ):
  472. encoder_output = model_kwargs["encoder_output"]
  473. model_kwargs["encoder_output"] = paddle.gather(encoder_output, index)
  474. if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
  475. role_ids = model_kwargs["role_ids"]
  476. model_kwargs["role_ids"] = paddle.gather(role_ids, index)
  477. return input_ids, model_kwargs
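# Illustrative sketch (hypothetical values): inputs are interleaved along the batch
# axis so each sample is repeated expand_size times back-to-back, which is the
# layout beam search and multi-sequence sampling expect.
#
#   >>> ids = paddle.to_tensor([[1, 2], [3, 4]])
#   >>> ids, kwargs = GenerationMixin.expand_inputs_for_generation(ids, expand_size=2)
#   >>> # ids -> [[1, 2], [1, 2], [3, 4], [3, 4]]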
  478. @staticmethod
  479. def update_model_kwargs_for_generation(
  480. outputs, model_kwargs, is_encoder_decoder=False
  481. ):
  482. # Update the model inputs during generation.
  483. # Note that if `token_type_ids` and `attention_mask` are in `model_kwargs`
  484. # and they contain pad values, the result vectors updated by this method
  485. # may be different from expected. In this case, you need to rewrite the
  486. # method.
  487. # update cache
  488. if (
  489. isinstance(outputs, tuple)
  490. and len(outputs) > 1
  491. and not isinstance(outputs[1], paddle.Tensor)
  492. ):
  493. model_kwargs["cache"] = outputs[1]
  494. model_kwargs["past_key_values"] = outputs[1]
  495. if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
  496. model_kwargs["cache"] = outputs.past_key_values
  497. model_kwargs["past_key_values"] = outputs.past_key_values
  498. # update token_type_ids with last value
  499. if (
  500. "token_type_ids" in model_kwargs
  501. and model_kwargs["token_type_ids"] is not None
  502. ):
  503. token_type_ids = model_kwargs["token_type_ids"]
  504. model_kwargs["token_type_ids"] = paddle.concat(
  505. [token_type_ids, token_type_ids[:, -1:]], axis=-1
  506. )
  507. # update position_ids
  508. if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
  509. position_ids = model_kwargs["position_ids"]
  510. model_kwargs["position_ids"] = paddle.concat(
  511. [position_ids, position_ids[..., -1:] + 1], axis=-1
  512. )
  513. # update attention_mask
  514. if not is_encoder_decoder and "attention_mask" in model_kwargs:
  515. attention_mask = model_kwargs["attention_mask"]
  516. # nn.Pad2D doesn't support the data type `bool`
  517. if convert_dtype(attention_mask.dtype) == "bool":
  518. attention_mask = paddle.cast(attention_mask, "int64")
  519. if len(attention_mask.shape) == 4:
  520. cur_device = paddle.get_device()
  521. if cur_device.split(":")[0] == "npu":
  522. attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
  523. attention_mask
  524. )
  525. attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
  526. else:
  527. attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
  528. attention_mask
  529. )
  530. attention_mask = nn.Pad2D(
  531. [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
  532. )(attention_mask)
  533. dtype = convert_dtype(attention_mask.dtype)
  534. if "int" in dtype:
  535. attention_mask[:, :, -1, -1] = 1
  536. elif "float" in dtype:
  537. attention_mask[:, :, -1, -1] = 0.0
  538. else:
  539. raise ValueError(
  540. "The data type of input `attention_mask` must "
  541. "be bool, int or float"
  542. )
  543. else:
  544. attention_mask = paddle.concat(
  545. [
  546. attention_mask,
  547. paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
  548. ],
  549. axis=-1,
  550. )
  551. model_kwargs["attention_mask"] = attention_mask
  552. # update role_ids
  553. if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
  554. role_ids = model_kwargs["role_ids"]
  555. model_kwargs["role_ids"] = paddle.concat(
  556. [role_ids, role_ids[:, -1:]], axis=-1
  557. )
  558. return model_kwargs
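# Illustrative sketch (2-D attention mask, hypothetical tensors): each call grows
# position_ids by `last + 1` and appends one column to the attention mask so the
# token generated at this step is visible at the next step.
#
#   >>> kwargs = {"position_ids": paddle.to_tensor([[0, 1, 2]]),
#   ...           "attention_mask": paddle.ones([1, 3], dtype="int64")}
#   >>> outputs = (paddle.zeros([1, 1, 8]),)   # stand-in for (logits,)
#   >>> kwargs = GenerationMixin.update_model_kwargs_for_generation(outputs, kwargs)
#   >>> # position_ids -> [[0, 1, 2, 3]]; attention_mask.shape -> [1, 4]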
  559. @staticmethod
  560. def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
  561. # update scores
  562. unfinished_scores = (
  563. scores * paddle.to_tensor(length, dtype=scores.dtype) + next_scores
  564. ) / (paddle.to_tensor(length, dtype=scores.dtype) + 1)
  565. scores = paddle.where(unfinished_flag, unfinished_scores, scores)
  566. return scores
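# Worked example (hypothetical numbers): scores holds the running mean of per-token
# log-probabilities, updated as (scores * length + next_scores) / (length + 1) for
# unfinished sequences and left untouched for finished ones.
#
#   >>> scores = paddle.to_tensor([[-0.5], [-1.0]])
#   >>> next_scores = paddle.to_tensor([[-2.0], [-2.0]])
#   >>> flag = paddle.to_tensor([[True], [False]])
#   >>> GenerationMixin.update_scores_for_generation(scores, next_scores, 1, flag)
#   >>> # -> [[-1.25], [-1.0]]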
  567. def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
  568. if "encoder_output" not in model_kwargs:
  569. # retrieve encoder hidden states
  570. encoder = self.get_encoder()
  571. encoder_kwargs = {
  572. argument: value
  573. for argument, value in model_kwargs.items()
  574. if not (
  575. argument.startswith("decoder_")
  576. or argument.startswith("cross_attn")
  577. or argument == "use_cache"
  578. )
  579. }
  580. # Give priority to inputs_embeds if it exists
  581. if "inputs_embeds" in encoder_kwargs:
  582. model_kwargs["encoder_output"] = encoder(**encoder_kwargs)
  583. else:
  584. model_kwargs["encoder_output"] = encoder(
  585. input_ids=input_ids, **encoder_kwargs
  586. )
  587. return model_kwargs
  588. def prepare_decoder_input_ids_for_generation(
  589. self, input_ids, decoder_start_token_id=None, bos_token_id=None
  590. ):
  591. decoder_start_token_id = (
  592. decoder_start_token_id
  593. if decoder_start_token_id is not None
  594. else self.config.decoder_start_token_id
  595. )
  596. decoder_start_token_id = (
  597. decoder_start_token_id
  598. if decoder_start_token_id is not None
  599. else bos_token_id
  600. )
  601. decoder_input_ids = (
  602. paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
  603. )
  604. return decoder_input_ids
  605. def get_decoder_start_token_id(
  606. self, decoder_start_token_id=None, bos_token_id=None
  607. ):
  608. decoder_start_token_id = (
  609. decoder_start_token_id
  610. if decoder_start_token_id is not None
  611. else self.config.decoder_start_token_id
  612. )
  613. bos_token_id = (
  614. bos_token_id if bos_token_id is not None else self.config.bos_token_id
  615. )
  616. if decoder_start_token_id is not None:
  617. return decoder_start_token_id
  618. elif self.config.decoder_start_token_id is not None:
  619. return self.config.decoder_start_token_id
  620. elif bos_token_id is not None:
  621. return bos_token_id
  622. elif self.config.bos_token_id is not None:
  623. return self.config.bos_token_id
  624. raise ValueError(
  625. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  626. )
  627. def prepare_inputs_for_generation(self, input_ids, **kwargs):
  628. # Implement in subclasses for custom behavior to prepare inputs in the
  629. # generate method.
  630. return {"input_ids": input_ids}
  631. def adjust_logits_during_generation(self, logits):
  632. # Implement in subclasses for custom behavior to adjust the logits in
  633. # the generate method.
  634. return logits
  635. def prepare_fast_entry(self, kwargs):
  636. return False
  637. def _convert_to_fast(self, kwargs):
  638. # try general convert
  639. pass
  640. def _build_fast(self, kwargs):
  641. self._fast_entry = False
  642. if kwargs["num_beam_groups"] != 1:
  643. # group_beam_search is not supported yet in the fast version
  644. raise AttributeError(
  645. "'num_beam_groups != 1' is not supported yet in the fast version"
  646. )
  647. if (
  648. paddle.get_default_dtype() == "float16"
  649. and kwargs["use_fp16_decoding"] is False
  650. ):
  651. logging.info(
  652. "Since the default dtype is float16, float16 would be used "
  653. "though 'use_fp16_decoding=False'."
  654. )
  655. kwargs["use_fp16_decoding"] = True
  656. self.prepare_fast_entry(kwargs)
  657. def set_pad_token_id(self, pad_token_id, eos_token_id):
  658. if pad_token_id is None and eos_token_id is not None:
  659. logging.warning(
  660. "Setting `pad_token_id` to `eos_token_id`:{} for "
  661. "open-end generation.".format(eos_token_id)
  662. )
  663. if isinstance(eos_token_id, list):
  664. pad_token_id = eos_token_id[0]
  665. else:
  666. pad_token_id = eos_token_id
  667. return pad_token_id
  668. @paddle.no_grad()
  669. def generate(
  670. self,
  671. input_ids: paddle.Tensor = None,
  672. generation_config: GenerationConfig = None,
  673. stopping_criteria: StoppingCriteria = None,
  674. streamer=None,
  675. synced_gpus: Optional[bool] = None,
  676. **kwargs,
  677. ):
  678. r"""
  679. The interface for generation tasks. This method generates sequences
  680. using a chosen decoding strategy. Currently, three decoding
  681. strategies are supported: "greedy_search", "sampling" and "beam_search".
  682. Args:
  683. input_ids (Tensor, optional): The input sequence ids for the
  684. generation. It is a Tensor with shape [batch_size, sequence_length].
  685. The data type should be int32 or int64. Defaults to None, in which
  686. case it is initialized as a Tensor with shape [1, 1], filled
  687. with the value `bos_token_id`.
  688. generation_config (`~generation.GenerationConfig`, *optional*):
  689. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  690. passed to generate matching the attributes of `generation_config` will override them. If
  691. `generation_config` is not provided, the default will be used, which has the following loading
  692. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  693. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  694. default values, whose documentation should be checked to parameterize generation.
  695. stopping_criteria (`StoppingCriteriaList`, *optional*):
  696. Custom stopping criteria that complement the default stopping criteria built from arguments and a
  697. generation config. If a stopping criterion is passed that is already created from the arguments or the
  698. generation config, an error is thrown. This feature is intended for advanced users.
  699. streamer (`~streamer.BaseStreamer`, *optional*):
  700. Streamer object that will be used to stream the generated sequences. Generated tokens are passed
  701. through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
  702. synced_gpus (`bool`, *optional*):
  703. Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
  704. `True` in a DeepSpeed ZeRO Stage 3 multi-GPU environment to avoid hanging if one GPU finishes
  705. generating before the other GPUs. Otherwise it will be set to `False`.
  706. kwargs (dict): It can be used to specify additional kwargs
  707. passed to the model.
  708. Returns:
  709. tuple[Tensor]: It is a tuple containing two elements: ids and scores.
  710. Each element is a Tensor.
  711. With the fields:
  712. - ids (Tensor):
  713. The ids of the generated sequences. It is a Tensor with shape
  714. [batch_size * num_return_sequences, sequence_length]. The data
  715. type is the same as the input `input_ids`.
  716. - scores (Tensor):
  717. The scores of the generated sequences. It is a Tensor with shape
  718. [batch_size * num_return_sequences, 1]. The data type is float32
  719. or float64, which is the same as the parameters in the model.
  720. Example:
  721. .. code-block::
  722. import paddle
  723. from paddlenlp.transformers import (
  724. UnifiedTransformerLMHeadModel,
  725. UnifiedTransformerTokenizer
  726. )
  727. paddle.seed(2)
  728. # Initialize the model and tokenizer
  729. model_name_or_path = 'unified_transformer-12L-cn-luge'
  730. model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
  731. tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
  732. # Prepare the model inputs.
  733. history = "早上好,今天空气质量不错。"
  734. inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
  735. add_start_token_as_response=True, return_tensors=True)
  736. .. code-block::
  737. # Generate the sequence by using "greedy_search" strategy
  738. ids, scores = model.generate(
  739. **inputs,
  740. decode_strategy="greedy_search")
  741. print(ids.shape, scores.shape)
  742. # [1, 3] [1, 1]
  743. sequence_ids = ids.cpu().numpy().tolist()[0]
  744. sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
  745. response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
  746. print(response)
  747. # 是的
  748. .. code-block::
  749. # Generate 2 sequences by using "sampling" strategy (top_k=5)
  750. generation_config = GenerationConfig(
  751. decode_strategy="sampling",
  752. top_k=5,
  753. num_return_sequences=2
  754. )
  755. ids, scores = model.generate(
  756. **inputs,
  757. generation_config=generation_config,
  758. )
  759. print(ids.shape, scores.shape)
  760. # [2, 7] [2, 1]
  761. response = []
  762. for sequence_ids in ids.cpu().numpy().tolist():
  763. sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
  764. text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
  765. response.append(text)
  766. print(response)
  767. # ['天气好,心情也好', '你也是']
  768. .. code-block::
  769. # Generate 2 sequences by using "beam_search" strategy (num_beams=5)
  770. generation_config = GenerationConfig(
  771. decode_strategy="beam_search",
  772. num_beams=5,
  773. num_return_sequences=2
  774. )
  775. ids, scores = model.generate(
  776. **inputs,
  777. generation_config=generation_config,
  778. )
  779. print(ids.shape, scores.shape)
  780. # [2, 3] [2, 1]
  781. response = []
  782. for sequence_ids in ids.cpu().numpy().tolist():
  783. sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
  784. text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
  785. response.append(text)
  786. print(response)
  787. # ['是的', '嗯嗯']
  788. """
  789. if generation_config is None:
  790. if self.generation_config is None or (
  791. self.generation_config._from_model_config
  792. and self.config._has_non_default_generation_parameters()
  793. ):
  794. new_generation_config = GenerationConfig.from_model_config(self.config)
  795. if new_generation_config != self.generation_config:
  796. logging.warning(
  797. "model.generation_config is in conflict with model.config, "
  798. "model.config is used."
  799. )
  800. self.generation_config = new_generation_config
  801. generation_config = self.generation_config
  802. # deepcopy so that model.generation_config is not modified
  803. generation_config = copy.deepcopy(generation_config)
  804. model_kwargs = generation_config.update(**kwargs)
  805. assert generation_config.decode_strategy in [
  806. "greedy_search",
  807. "sampling",
  808. "beam_search",
  809. ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
  810. generation_config.decode_strategy
  811. )
  812. if getattr(self, "deprecated_warnings", None) is None:
  813. self.deprecated_warnings = {}
  814. use_fast = False
  815. if "use_faster" in model_kwargs:
  816. raise ValueError("`use_faster` is deprecated now.")
  817. if "use_fast" in model_kwargs:
  818. raise ValueError("`use_fast` is deprecated now.")
  819. bos_token_id = (
  820. generation_config.bos_token_id
  821. if generation_config.bos_token_id is not None
  822. else self.config.bos_token_id
  823. )
  824. eos_token_id = (
  825. generation_config.eos_token_id
  826. if generation_config.eos_token_id is not None
  827. else self.config.eos_token_id
  828. )
  829. pad_token_id = (
  830. generation_config.pad_token_id
  831. if generation_config.pad_token_id is not None
  832. else self.config.pad_token_id
  833. )
  834. forced_bos_token_id = (
  835. generation_config.forced_bos_token_id
  836. if generation_config.forced_bos_token_id is not None
  837. else self.config.forced_bos_token_id
  838. )
  839. forced_eos_token_id = (
  840. generation_config.forced_eos_token_id
  841. if generation_config.forced_eos_token_id is not None
  842. else self.config.forced_eos_token_id
  843. )
  844. decoder_start_token_id = (
  845. generation_config.decoder_start_token_id
  846. if generation_config.decoder_start_token_id is not None
  847. else self.config.decoder_start_token_id
  848. )
  849. no_repeat_ngram_size = (
  850. generation_config.no_repeat_ngram_size
  851. if generation_config.no_repeat_ngram_size is not None
  852. else self.config.no_repeat_ngram_size
  853. )
  854. if getattr(self, "_fast_entry", None) is not False and use_fast:
  855. fg_args = locals()
  856. fg_args.pop("self")
  857. fg_args.pop("__class__", None)
  858. model_kwargs = fg_args.pop("model_kwargs")
  859. fg_args.update(model_kwargs)
  860. try:
  861. if getattr(self, "_fast_entry", None) is None:
  862. self._build_fast(fg_args)
  863. if self._fast_entry:
  864. output = self._fast_entry(**fg_args)
  865. if isinstance(output, tuple):
  866. output_ids, dummy_score = output
  867. else:
  868. output_ids = output
  869. # make the result and the fast result consistent
  870. dummy_score = None
  871. if generation_config.decode_strategy == "beam_search":
  872. output_ids = output_ids.transpose([1, 2, 0])
  873. output_ids = output_ids[
  874. :, : generation_config.num_return_sequences, :
  875. ].reshape([-1, output_ids.shape[-1]])
  876. if dummy_score is not None:
  877. dummy_score = dummy_score[
  878. :, : generation_config.num_return_sequences
  879. ].flatten()
  880. else:
  881. output_ids = output_ids.transpose([1, 0])
  882. return output_ids, dummy_score
  883. except Exception as e:
  884. fg_args["model_kwargs"] = model_kwargs
  885. # TODO
  886. # Prevent self._convert_to_fast from throwing an Exception
  887. self._convert_to_fast(fg_args)
  888. logging.warning(e)
  889. logging.warning(
  890. "FastGeneration is not available, "
  891. "and the original version would be used instead."
  892. )
  893. # input_ids in model_kwargs is supported
  894. if "input_ids" in model_kwargs:
  895. _input_ids = model_kwargs.pop("input_ids")
  896. if input_ids is None:
  897. input_ids = _input_ids
  898. # params check
  899. if input_ids is None and "inputs_embeds" not in model_kwargs:
  900. # Init `input_ids` with bos_token_id
  901. input_ids = self.prepare_input_ids_for_generation(bos_token_id)
  902. elif "inputs_embeds" in model_kwargs:
  903. # Add input embeds support
  904. input_ids = self.prepare_input_ids_for_generation(
  905. bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
  906. )
  907. if model_kwargs.get("attention_mask", None) is None:
  908. # TODO
  909. # Init `attention_mask` depending on `pad_token_id`
  910. model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
  911. input_ids, pad_token_id, eos_token_id
  912. )
  913. self.is_encoder_decoder = self.config.is_encoder_decoder
  914. if self.is_encoder_decoder:
  915. model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
  916. input_ids, model_kwargs
  917. )
  918. # set input_ids as decoder_input_ids
  919. if "decoder_input_ids" in model_kwargs:
  920. input_ids = model_kwargs.pop("decoder_input_ids")
  921. else:
  922. input_ids = self.prepare_decoder_input_ids_for_generation(
  923. input_ids, decoder_start_token_id, bos_token_id
  924. )
  925. # streamer
  926. if streamer is not None:
  927. # streamer does not support the beam_search strategy
  928. if (
  929. generation_config.decode_strategy == "beam_search"
  930. or generation_config.num_beams > 1
  931. ):
  932. raise ValueError(
  933. "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
  934. )
  935. pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
  936. if (
  937. generation_config.max_length != 0
  938. and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
  939. ):
  940. logging.warning(
  941. "`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
  942. )
  943. generation_config.max_new_tokens = generation_config.max_length
  944. if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
  945. logging.warning(
  946. "`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
  947. )
  948. generation_config.min_new_tokens = generation_config.min_length
  949. max_length = generation_config.max_new_tokens
  950. min_length = generation_config.min_new_tokens
  951. input_len = input_ids.shape[-1]
  952. min_len = input_len + min_length
  953. max_len = input_len + max_length
  954. logits_processors = self.get_logits_processor(
  955. min_length=min_len if min_length > 0 else None,
  956. max_length=max_len,
  957. eos_token_id=eos_token_id,
  958. forced_bos_token_id=forced_bos_token_id,
  959. forced_eos_token_id=forced_eos_token_id,
  960. num_beams=generation_config.num_beams,
  961. num_beam_groups=generation_config.num_beam_groups,
  962. diversity_rate=generation_config.diversity_rate,
  963. repetition_penalty=generation_config.repetition_penalty,
  964. no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
  965. logits_processors=(
  966. model_kwargs["logits_processors"]
  967. if "logits_processors" in model_kwargs
  968. and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
  969. else None
  970. ),
  971. )
  972. if "logits_processors" in model_kwargs:
  973. model_kwargs.pop("logits_processors")
  974. model_kwargs["use_cache"] = generation_config.use_cache
  975. stopping_criteria = (
  976. stopping_criteria
  977. if stopping_criteria is not None
  978. else StoppingCriteriaList()
  979. )
  980. if generation_config.decode_strategy == "greedy_search":
  981. if generation_config.num_return_sequences > 1:
  982. raise ValueError(
  983. "`num_return_sequences` has to be 1, but is {} "
  984. "when doing greedy search.".format(
  985. generation_config.num_return_sequences
  986. )
  987. )
  988. return self.greedy_search(
  989. input_ids,
  990. logits_processors,
  991. max_len,
  992. pad_token_id,
  993. eos_token_id,
  994. stopping_criteria=stopping_criteria,
  995. streamer=streamer,
  996. fast_ptq_sampling=generation_config.fast_ptq_sampling,
  997. trunc_input=generation_config.trunc_input,
  998. synced_gpus=synced_gpus,
  999. **model_kwargs,
  1000. )
  1001. elif generation_config.decode_strategy == "sampling":
  1002. if generation_config.num_return_sequences > 1:
  1003. input_ids, model_kwargs = self.expand_inputs_for_generation(
  1004. input_ids,
  1005. expand_size=generation_config.num_return_sequences,
  1006. **model_kwargs,
  1007. )
  1008. return self.sample(
  1009. input_ids,
  1010. logits_processors,
  1011. max_len,
  1012. pad_token_id,
  1013. eos_token_id,
  1014. generation_config.top_k,
  1015. generation_config.top_p,
  1016. generation_config.temperature,
  1017. stopping_criteria=stopping_criteria,
  1018. streamer=streamer,
  1019. fast_ptq_sampling=generation_config.fast_ptq_sampling,
  1020. trunc_input=generation_config.trunc_input,
  1021. synced_gpus=synced_gpus,
  1022. **model_kwargs,
  1023. )
  1024. elif generation_config.decode_strategy == "beam_search":
  1025. batch_size = input_ids.shape[0]
  1026. if generation_config.num_return_sequences > generation_config.num_beams:
  1027. raise ValueError(
  1028. "`num_return_sequences` has to be smaller or equal to "
  1029. "`num_beams`. But received `num_return_sequences` is {}, "
  1030. "`num_beams` is {}".format(
  1031. generation_config.num_return_sequences,
  1032. generation_config.num_beams,
  1033. )
  1034. )
  1035. if generation_config.num_beams <= 1:
  1036. raise ValueError(
  1037. "`num_beams` has to be bigger than 1. But received "
  1038. "`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
  1039. "should be 'greedy_search'".format(generation_config.num_beams)
  1040. )
  1041. if generation_config.num_beam_groups > 1:
  1042. diverse_beam_scorer = BeamSearchScorer(
  1043. batch_size=batch_size,
  1044. max_length=max_len,
  1045. num_beams=generation_config.num_beams,
  1046. length_penalty=generation_config.length_penalty,
  1047. do_early_stopping=generation_config.early_stopping,
  1048. num_beam_hyps_to_keep=generation_config.num_return_sequences,
  1049. num_beam_groups=generation_config.num_beam_groups,
  1050. )
  1051. # interleave with `num_beams`
  1052. input_ids, model_kwargs = self.expand_inputs_for_generation(
  1053. input_ids, expand_size=generation_config.num_beams, **model_kwargs
  1054. )
  1055. return self.group_beam_search(
  1056. input_ids,
  1057. diverse_beam_scorer,
  1058. logits_processors,
  1059. max_len,
  1060. pad_token_id,
  1061. eos_token_id,
  1062. stopping_criteria=stopping_criteria,
  1063. fast_ptq_sampling=generation_config.fast_ptq_sampling,
  1064. trunc_input=generation_config.trunc_input,
  1065. synced_gpus=synced_gpus,
  1066. **model_kwargs,
  1067. )
  1068. else:
  1069. beam_scorer = BeamSearchScorer(
  1070. batch_size=batch_size,
  1071. max_length=max_len,
  1072. num_beams=generation_config.num_beams,
  1073. length_penalty=generation_config.length_penalty,
  1074. do_early_stopping=generation_config.early_stopping,
  1075. num_beam_hyps_to_keep=generation_config.num_return_sequences,
  1076. )
  1077. input_ids, model_kwargs = self.expand_inputs_for_generation(
  1078. input_ids, expand_size=generation_config.num_beams, **model_kwargs
  1079. )
  1080. return self.beam_search(
  1081. input_ids,
  1082. beam_scorer,
  1083. logits_processors,
  1084. max_len,
  1085. generation_config.diversity_rate,
  1086. pad_token_id,
  1087. eos_token_id,
  1088. stopping_criteria=stopping_criteria,
  1089. fast_ptq_sampling=generation_config.fast_ptq_sampling,
  1090. trunc_input=generation_config.trunc_input,
  1091. synced_gpus=synced_gpus,
  1092. **model_kwargs,
  1093. )
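# Illustrative sketch (hedged: assumes a MaxLengthCriteria class is exported by the
# same stopping_criteria module referenced in the comments below; `model` and
# `inputs` are as in the docstring examples above): a custom StoppingCriteriaList
# can be passed alongside a GenerationConfig.
#
#   >>> criteria = StoppingCriteriaList()
#   >>> criteria.append(MaxLengthCriteria(max_length=64))  # assumed class, see note above
#   >>> ids, scores = model.generate(
#   ...     **inputs,
#   ...     generation_config=GenerationConfig(decode_strategy="sampling", top_p=0.9),
#   ...     stopping_criteria=criteria)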
  1094. def greedy_search(
  1095. self,
  1096. input_ids,
  1097. logits_processors,
  1098. max_length,
  1099. pad_token_id,
  1100. eos_token_id,
  1101. stopping_criteria=None,
  1102. streamer=None,
  1103. fast_ptq_sampling=False,
  1104. trunc_input=True,
  1105. synced_gpus=False,
  1106. **model_kwargs,
  1107. ):
  1108. logits_processors = (
  1109. logits_processors
  1110. if logits_processors is not None
  1111. else LogitsProcessorList()
  1112. )
  1113. # max_length will be converted to MaxLengthCriteria
  1114. stopping_criteria = (
  1115. stopping_criteria
  1116. if stopping_criteria is not None
  1117. else StoppingCriteriaList()
  1118. )
  1119. if max_length is not None:
  1120. # logging.warning(
  1121. # "`max_length` is deprecated in this function, use"
  1122. # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
  1123. # )
  1124. stopping_criteria = validate_stopping_criteria(
  1125. stopping_criteria, max_length
  1126. )
  1127. batch_size, cur_len = input_ids.shape
  1128. origin_len = cur_len
  1129. unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
  1130. scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
  1131. generate_end = False
  1132. while True:
  1133. if synced_gpus:
  1134. # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
  1135. # The following logic allows an early break if all peers finished generating their sequence
  1136. this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
  1137. # send 0.0 if we finished, 1.0 otherwise
  1138. dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
  1139. # did all peers finish? the reduced sum will be 0.0 then
  1140. if this_peer_finished_flag.item() == 0.0:
  1141. break
  1142. # prepare model inputs & get model output
  1143. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  1144. outputs = self(**model_inputs)
  1145. if synced_gpus and generate_end:
  1146. continue # don't waste resources running the code we don't need
  1147. if isinstance(outputs, tuple):
  1148. logits = outputs[0]
  1149. elif isinstance(outputs, ModelOutput):
  1150. logits = outputs.logits
  1151. else:
  1152. logits = outputs
  1153. # [batch_size, vocab_size]
  1154. next_token_logits = logits[:, -1, :]
  1155. # pre-process distribution
  1156. next_token_logits = self.adjust_logits_during_generation(next_token_logits)
  1157. probs = logits_processors(input_ids, next_token_logits)
  1158. # greedy
  1159. next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
  1160. next_scores = paddle.index_sample(probs, next_tokens)
  1161. if eos_token_id is not None:
  1162. next_tokens = paddle.where(
  1163. unfinished_flag,
  1164. next_tokens,
  1165. paddle.full_like(next_tokens, pad_token_id),
  1166. )
  1167. scores = self.update_scores_for_generation(
  1168. scores, next_scores, cur_len - origin_len, unfinished_flag
  1169. )
  1170. cur_len += 1
  1171. input_ids = paddle.concat([input_ids, next_tokens], axis=1)
  1172. if streamer is not None:
  1173. if self.config.tensor_parallel_rank == 0:
  1174. streamer.put(next_tokens.cpu())
  1175. if stopping_criteria(input_ids, scores):
  1176. generate_end = True
  1177. if eos_token_id is not None:
  1178. unfinished_flag = get_unfinished_flag(
  1179. input_ids, unfinished_flag, eos_token_id
  1180. )
  1181. if not paddle.any(unfinished_flag):
  1182. generate_end = True
  1183. # Stop when there is a </s> in all sentences
  1184. if generate_end and not synced_gpus:
  1185. break
  1186. model_kwargs = self.update_model_kwargs_for_generation(
  1187. outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
  1188. )
  1189. if fast_ptq_sampling:
  1190. break
  1191. if streamer is not None:
  1192. streamer.end()
  1193. return input_ids[:, origin_len:] if trunc_input else input_ids, scores
  1194. def sample(
  1195. self,
  1196. input_ids,
  1197. logits_processors,
  1198. max_length,
  1199. pad_token_id,
  1200. eos_token_id,
  1201. top_k=None,
  1202. top_p=None,
  1203. temperature=None,
  1204. min_tokens_to_keep=1,
  1205. stopping_criteria=None,
  1206. streamer=None,
  1207. fast_ptq_sampling=False,
  1208. trunc_input=True,
  1209. synced_gpus=False,
  1210. **model_kwargs,
  1211. ):
  1212. logits_processors = (
  1213. logits_processors
  1214. if logits_processors is not None
  1215. else LogitsProcessorList()
  1216. )
  1217. # max_length will be converted to MaxLengthCriteria
  1218. stopping_criteria = (
  1219. stopping_criteria
  1220. if stopping_criteria is not None
  1221. else StoppingCriteriaList()
  1222. )
  1223. if max_length is not None:
  1224. # logging.warning(
  1225. # "`max_length` is deprecated in this function, use"
  1226. # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
  1227. # )
  1228. stopping_criteria = validate_stopping_criteria(
  1229. stopping_criteria, max_length
  1230. )
  1231. batch_size, cur_len = input_ids.shape
  1232. origin_len = cur_len
  1233. unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
  1234. scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
  1235. generate_end = False
  1236. while True:
  1237. if synced_gpus:
  1238. # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
  1239. # The following logic allows an early break if all peers finished generating their sequence
  1240. this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
  1241. # send 0.0 if we finished, 1.0 otherwise
  1242. dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
  1243. # did all peers finish? the reduced sum will be 0.0 then
  1244. if this_peer_finished_flag.item() == 0.0:
  1245. break
  1246. # prepare model inputs & get model output
  1247. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  1248. # NOTE: to decrease ref-count and clear outdate cache in-time
  1249. model_kwargs["cache"] = None
  1250. model_kwargs["past_key_values"] = None
  1251. outputs = self(**model_inputs)
  1252. if synced_gpus and generate_end:
  1253. continue # don't waste resources running the code we don't need
  1254. if isinstance(outputs, tuple):
  1255. logits = outputs[0]
  1256. elif isinstance(outputs, ModelOutput):
  1257. logits = outputs.logits
  1258. else:
  1259. logits = outputs
  1260. # [batch_size, vocab_size]
  1261. logits = logits[:, -1, :]
  1262. # pre-process distribution
  1263. logits = self.adjust_logits_during_generation(logits)
  1264. logits = logits_processors(input_ids, logits)
  1265. # sample
  1266. origin_probs = F.softmax(logits)
  1267. origin_probs = paddle.log(origin_probs)
  1268. if temperature is not None and temperature != 1.0:
  1269. logits = logits / temperature
  1270. probs = F.softmax(logits)
  1271. if top_k is not None and top_k != 0:
  1272. probs = TopKProcess(probs, top_k, min_tokens_to_keep)
  1273. if top_p is not None and top_p < 1.0:
  1274. probs = TopPProcess(probs, top_p, min_tokens_to_keep)
  1275. if paddle.device.is_compiled_with_custom_device("gcu"):
  1276. probs = paddle.cast(probs, "float32")
  1277. if paddle.device.is_compiled_with_xpu():
  1278. probs = paddle.cast(probs, "float32")
  1279. # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
  1280. next_tokens = paddle.multinomial(probs)
  1281. if self.config.tensor_parallel_degree > 1:
  1282. # Maybe no need to broadcast if seed is set correctly.
  1283. from paddle.distributed import fleet
  1284. try:
  1285. hcg = fleet.get_hybrid_communicate_group()
  1286. group = hcg.get_model_parallel_group()
  1287. src = hcg.get_model_parallel_group_src_rank()
  1288. except:
  1289. group, src = None, 0
  1290. paddle.distributed.broadcast(next_tokens, src=src, group=group)
  1291. # config does not include pipeline_parallel_degree, and pipeline parallel
  1292. # uses trainer.model_wrapped to run in both train and predict mode
  1293. # which has pp_group as a attribute
  1294. # TODO(guosheng): only let the last stage of pipeline to do softmax
  1295. # and sampling, and then broadcast to avoid broadcast logits.
  1296. if getattr(self, "pp_group", None) is not None:
  1297. paddle.distributed.broadcast(
  1298. next_tokens,
  1299. src=self.pp_group.ranks[0],
  1300. group=self.pp_group, # use rank 0 for same seed to check
  1301. )
  1302. next_scores = paddle.index_sample(origin_probs, next_tokens)
  1303. if eos_token_id is not None:
  1304. next_tokens = paddle.where(
  1305. unfinished_flag,
  1306. next_tokens,
  1307. paddle.full_like(next_tokens, pad_token_id),
  1308. )
  1309. scores = self.update_scores_for_generation(
  1310. scores, next_scores, cur_len - origin_len, unfinished_flag
  1311. )
  1312. cur_len += 1
  1313. input_ids = paddle.concat([input_ids, next_tokens], axis=1)
  1314. if streamer is not None:
  1315. if self.config.tensor_parallel_rank == 0:
  1316. streamer.put(next_tokens.cpu())
  1317. if stopping_criteria(input_ids, scores):
  1318. generate_end = True
  1319. if eos_token_id is not None:
  1320. unfinished_flag = get_unfinished_flag(
  1321. input_ids, unfinished_flag, eos_token_id
  1322. )
  1323. if not paddle.any(unfinished_flag):
  1324. generate_end = True
  1325. # Stop when there is a </s> in all sentences
  1326. if generate_end and not synced_gpus:
  1327. break
  1328. model_kwargs = self.update_model_kwargs_for_generation(
  1329. outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
  1330. )
  1331. if fast_ptq_sampling:
  1332. break
  1333. if streamer is not None:
  1334. streamer.end()
  1335. return input_ids[:, origin_len:] if trunc_input else input_ids, scores
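
    # Sampling in `sample` above proceeds as: temperature scaling -> softmax ->
    # optional TopKProcess / TopPProcess filtering -> paddle.multinomial. A minimal
    # sketch of that pipeline with hypothetical hyperparameters (not executed here):
    #
    #   probs = F.softmax(logits / 0.7)                            # temperature = 0.7
    #   probs = TopKProcess(probs, top_k=50, min_tokens_to_keep=1)
    #   probs = TopPProcess(probs, top_p=0.9, min_tokens_to_keep=1)
    #   next_tokens = paddle.multinomial(probs)                    # [batch_size, 1]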
    def _get_model_inputs_spec(self, dtype: str):
        spec = {
            "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
            "attention_mask": paddle.static.InputSpec(
                shape=[None, None], dtype="int64"
            ),
        }
        if "position_ids" in inspect.getfullargspec(self.forward).args:
            spec["position_ids"] = paddle.static.InputSpec(
                shape=[None, None], dtype="int64"
            )
        return spec
    def to_static(self, path: str, config: dict):
        """Export the generation model to a static-graph inference model.

        Args:
            path (str): path where the inference model is saved
            config (dict): configuration for generation, with keys such as:
                bos_token_id (int): token id of begin-of-sentence
                eos_token_id (int): token id of end-of-sentence
                pad_token_id (int): token id of the pad token
                use_top_p (bool): whether to use the top_p decoding strategy
        """
        use_top_p = config.get("use_top_p", True)

        top_k_spec = (
            paddle.static.InputSpec(shape=[1], dtype="int64") if not use_top_p else 0
        )
        top_p_spec = (
            paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
        )
        temperature = (
            paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
        )

        dtype = config.get("dtype", None)
        logits_processors = config.get("logits_processors", None)

        model_inputs_spec = self._get_model_inputs_spec(dtype)

        input_spec = [
            model_inputs_spec["input_ids"],  # input_ids
            model_inputs_spec["attention_mask"],  # attention_mask
            model_inputs_spec.get("position_ids", None),  # position_ids
            logits_processors,  # logits_processors
            paddle.static.InputSpec(shape=[1], dtype="int64"),  # max_new_tokens
            self.generation_config.pad_token_id or config.get("pad_token_id", None),  # pad_token_id
            self.generation_config.eos_token_id or config.get("eos_token_id", None),  # eos_token_id
            top_k_spec,  # top_k
            top_p_spec,  # top_p
            temperature,  # temperature
            1,  # min_tokens_to_keep
        ]

        model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)

        paddle.jit.save(model, path)
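
    # Example usage sketch for `to_static`; the model class, checkpoint name, output
    # path, and token ids below are illustrative assumptions, not values from this file:
    #
    #   model = AutoModelForCausalLM.from_pretrained("some/model")
    #   model.to_static(
    #       "./inference/generation",
    #       {
    #           "use_top_p": True,        # export the top-p sampling branch
    #           "dtype": "float16",
    #           "pad_token_id": 0,
    #           "eos_token_id": 2,
    #           "logits_processors": None,
    #       },
    #   )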
    def sample_d2s(
        self,
        input_ids,
        attention_mask,
        position_ids,
        logits_processors,
        max_new_tokens,
        pad_token_id,
        eos_token_id,
        top_k=None,
        top_p=None,
        temperature=None,
        min_tokens_to_keep=1,
    ):
        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)

        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
            use_top_p = False
        elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
            use_top_p = True
        # top_k and top_p are constant values
        elif isinstance(top_p, float) or isinstance(top_k, int):
            use_top_p = True
        else:
            if top_p is None and top_k is None:
                raise ValueError("top_k and top_p should not both be None")
            raise ValueError(
                "top_k and top_p should not both be specified as InputSpec; exactly one InputSpec is expected"
            )

        batch_size, cur_len = input_ids.shape
        # used for compute on gpu, avoid memcpy D2H
        cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

        origin_len = input_ids.shape[1]
        # used for compute on gpu, avoid memcpy D2H
        origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

        # use_cache is immutable, so we split it off from the other, mutable kwargs.
        immutable = {"use_cache": True}
        model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}

        def _forward_(**args):
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, **args, **immutable
            )
            assert "use_cache" in model_inputs
            del model_inputs["use_cache"]
            return self(**model_inputs, **immutable)

        def _post_process_(
            outputs,
            input_ids,
            cur_len,
            origin_len,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        ):
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)
            logits = logits_processors(input_ids, logits)
            probs = F.softmax(logits)

            # sample
            origin_probs = F.log_softmax(logits)

            # compute next_tokens
            if use_top_p:
                logits = logits / temperature
                top_ps_tensor = paddle.full(
                    shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
                )
                _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
            else:
                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
                if top_k == 1:
                    next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
                else:
                    next_tokens = paddle.multinomial(probs)

            next_scores = paddle.index_sample(origin_probs, next_tokens)
            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )

            if eos_token_id is not None:
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            input_ids = paddle.concat([input_ids, next_tokens], axis=1)

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            return input_ids, scores, unfinished_flag, model_kwargs

        outputs = _forward_(**model_kwargs)
        input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
            outputs,
            input_ids,
            cur_len_gpu,
            origin_len_gpu,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        )
        cur_len += 1
        cur_len_gpu += 1

        attn_mask = model_kwargs["attention_mask"]
        # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
        model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
        max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

        while cur_len < max_new_tokens and paddle.any(unfinished_flag):
            input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
                _forward_(**model_kwargs),
                input_ids,
                cur_len_gpu,
                origin_len_gpu,
                scores,
                unfinished_flag,
                model_kwargs,
                pad_token_id,
            )
            cur_len += 1
            cur_len_gpu += 1

        return input_ids[:, origin_len:], scores
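
    # After `to_static` exports `sample_d2s`, the saved program can be loaded back with
    # the static-graph API. A rough sketch; the path and the exact feed tensors depend
    # on the exported input_spec and are assumptions here:
    #
    #   exported = paddle.jit.load("./inference/generation")
    #   out_ids, out_scores = exported(input_ids, attention_mask, position_ids, ...)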
    def reorder_cache(self, cache, beam_idx):
        cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
        return cache
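
    # `reorder_cache` applies `paddle.index_select(x, beam_idx)` to every tensor in an
    # arbitrarily nested cache structure. A tiny sketch with a hypothetical cache layout
    # (one layer of key/value tensors, batch_size * num_beams = 4):
    #
    #   cache = [(paddle.rand([4, 2, 8, 16]), paddle.rand([4, 2, 8, 16]))]
    #   beam_idx = paddle.to_tensor([2, 2, 0, 1])   # beams selected at this step
    #   cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)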
    def beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        diversity_rate,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be converted to a MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #     "`max_length` is deprecated in this function, use"
            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        beam_scores = paddle.zeros(
            (batch_size, num_beams), dtype=paddle.get_default_dtype()
        )
        beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs)

            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)

            # beam search
            # [batch_size * num_beams, vocab_size]
            next_scores = F.softmax(logits)
            next_scores = paddle.log(next_scores)
            next_scores = logits_processors(input_ids, next_scores)
            next_scores = next_scores + beam_scores.unsqueeze(-1)

            vocab_size = next_scores.shape[-1]
            if diversity_rate == 0.0:
                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])

                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size
            else:
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                sibling_score = (
                    paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
                    * diversity_rate
                )

                diversed_score = next_scores - sibling_score

                next_scores = next_scores.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                next_tokens = next_tokens.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                diversed_score = diversed_score.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                diversed_score, diversed_tokens = paddle.topk(
                    diversed_score, 2 * num_beams, axis=1
                )

                # TODO
                # Use gather_nd() to select the original tokens and scores
                next_scores = paddle.stack(
                    [
                        paddle.index_select(next_scores[i], diversed_tokens[i])
                        for i in range(next_scores.shape[0])
                    ]
                )
                next_tokens = paddle.stack(
                    [
                        paddle.index_select(next_tokens[i], diversed_tokens[i])
                        for i in range(next_tokens.shape[0])
                    ]
                )

                next_indices = diversed_tokens // (2 * num_beams)

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_scores,
                next_tokens,
                next_indices,
                origin_len=origin_len,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            # beam_idx may contain element -1 and cause error
            # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

            cur_len += 1
            input_ids = paddle.concat(
                [
                    paddle.index_select(input_ids, beam_idx),
                    beam_next_tokens.unsqueeze(-1),
                ],
                axis=-1,
            )

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )
            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], beam_idx
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )

            if fast_ptq_sampling:
                break

        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
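
    # In the `diversity_rate != 0.0` branch above, the k-th ranked candidate of a beam
    # is penalised by `k * diversity_rate` before re-ranking. A worked example with
    # hypothetical scores: for diversity_rate = 0.5 and per-beam top scores
    # [-1.0, -1.2, -1.3, -1.4], the penalties are [0.5, 1.0, 1.5, 2.0], so the
    # diversified scores become [-1.5, -2.2, -2.8, -3.4], discouraging many beams from
    # keeping near-identical continuations.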
    def group_beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be converted to a MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #     "`max_length` is deprecated in this function, use"
            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups

        batch_beam_size, cur_len = input_ids.shape
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        beam_scores = paddle.full(
            (batch_size, num_beams),
            get_scale_by_dtype(return_positive=False),
            dtype="float32",
        )
        # initialise the score of the first beam of each group with 0 and the rest with a
        # large negative value. This ensures that the beams in the same group don't all
        # produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # predicted tokens in cur_len step
            current_tokens = paddle.zeros(
                shape=[batch_size * num_beams], dtype=input_ids.dtype
            )

            # indices which will form the beams in the next time step
            reordering_indices = paddle.zeros(
                shape=[batch_size * num_beams], dtype="int64"
            )

            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs)

            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [
                            batch_idx * num_beams + idx
                            for idx in range(group_start_idx, group_end_idx)
                        ]
                    )

                group_input_ids = input_ids[batch_group_indices]

                if isinstance(outputs, tuple):
                    logits = outputs[0]
                elif isinstance(outputs, ModelOutput):
                    logits = outputs.logits
                else:
                    logits = outputs

                logits = logits[:, -1, :]
                logits = paddle.index_select(
                    logits, paddle.to_tensor(batch_group_indices)
                )
                logits = self.adjust_logits_during_generation(logits)

                next_scores = F.softmax(logits)
                next_scores = paddle.log(next_scores)
                vocab_size = next_scores.shape[-1]

                next_scores = logits_processors(
                    group_input_ids,
                    next_scores,
                    current_tokens=current_tokens,
                    beam_group_idx=beam_group_idx,
                )
                next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
                    -1
                )

                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, group_size * vocab_size])

                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * group_size, axis=1
                )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_scores,
                    next_tokens,
                    next_indices,
                    origin_len=origin_len,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )

                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                # beam_idx may contain element -1 and cause error
                # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
                beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = paddle.concat(
                    [
                        paddle.index_select(group_input_ids, index=beam_idx),
                        beam_next_tokens.unsqueeze(-1),
                    ],
                    axis=-1,
                )
                current_tokens[batch_group_indices] = beam_next_tokens

                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size)
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            input_ids = paddle.concat(
                [input_ids, current_tokens.unsqueeze(-1)], axis=-1
            )

            cur_len += 1

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )

            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], reordering_indices
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], reordering_indices
                )

            if fast_ptq_sampling:
                break

        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
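
    # The `reordering_indices` expression above maps a group-local beam index back to a
    # global beam index: num_beams * (beam_idx // group_size) + group_start_idx +
    # (beam_idx % group_size). A worked example with assumed sizes: num_beams = 4,
    # two groups (group_size = 2), group_start_idx = 2, and a group-local beam_idx of 3
    # (batch item 1, local beam 1) gives 4 * 1 + 2 + 1 = 7 as the global beam to reorder
    # the cache with.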