# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC
from collections import OrderedDict
from typing import Callable, Dict, List, Tuple, Union

import numpy as np
import paddle
from paddle.nn.layer.layers import in_declarative_mode


class LogitsProcessor(ABC):
    """
    Abstract base class for all logit processors that can be applied during
    generation.
    """

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. "
            "Only classes inheriting this class can be called."
        )


class LogitsProcessorList:
    """Uses an ordered dict to store the processors."""

    def __init__(self, processors: List[LogitsProcessor] = None) -> None:
        self._processors = OrderedDict()
        processors = processors or []
        for processor in processors:
            self.append(processor)

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor, **kwargs):
        for processor in self._processors.values():
            processor_args = inspect.signature(processor.__call__).parameters
            if len(processor_args) > 2:
                assert all(
                    arg in kwargs for arg in list(processor_args.keys())[2:]
                ), f"The parameters don't match for {processor.__class__}"
                logits = processor(input_ids, logits, **kwargs)
            else:
                logits = processor(input_ids, logits)
        return logits

    def append(self, processor: LogitsProcessor):
        self._processors[len(self._processors)] = processor
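
# Illustrative usage sketch (not part of the library API; shapes and token ids are
# made-up placeholders): compose several processors and apply them to a batch of
# next-token logits in one call.
#
#     processors = LogitsProcessorList([
#         MinLengthLogitsProcessor(min_length=10, eos_token_id=2),
#         RepetitionPenaltyLogitsProcessor(penalty=1.2),
#     ])
#     input_ids = paddle.to_tensor([[5, 17, 9]])       # [batch_size, cur_len]
#     next_token_logits = paddle.rand([1, 32000])      # [batch_size, vocab_size]
#     next_token_logits = processors(input_ids, next_token_logits)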


class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    Enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (int): The minimum length of the generated sequence.
        eos_token_id (int): The id of the `end-of-sequence` token.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        if min_length < 0 and not in_declarative_mode():
            raise ValueError(
                "`min_length` should be a positive integer, but is {}".format(min_length)
            )
        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(
                "`eos_token_id` should be a positive integer, but is {}".format(eos_token_id)
            )
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            logits[:, self.eos_token_id] = paddle.finfo(logits.dtype).min
        return logits
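
# Illustrative sketch (placeholder values, not from this file): with min_length=5 and
# eos_token_id=2, any step where the current sequence has fewer than 5 tokens gets its
# EOS logit forced to the dtype minimum, so EOS cannot be selected yet.
#
#     proc = MinLengthLogitsProcessor(min_length=5, eos_token_id=2)
#     logits = proc(paddle.to_tensor([[11, 7, 3]]), paddle.rand([1, 100]))
#     # logits[0, 2] now holds the smallest representable value for its dtype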


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (float):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not (penalty > 0) and not in_declarative_mode():
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        # Gather the logits of every token that already appears in `input_ids` and penalize them.
        score = paddle.index_sample(logits, input_ids)
        score = paddle.where(score < 0, score * self.penalty, score / self.penalty)
        # Offset the token ids per row so the penalized scores can be scattered back into the flattened logits.
        input_ids = (
            input_ids
            + paddle.arange(logits.shape[0], dtype="int64").unsqueeze(-1) * logits.shape[-1]
        )
        outputs = paddle.scatter(
            logits.flatten(), input_ids.flatten(), score.flatten()
        ).reshape(logits.shape)
        return outputs
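
# Worked sketch (made-up numbers): with penalty=2.0, a previously generated token whose
# logit is 3.0 becomes 1.5 (divided by the penalty), while one whose logit is -3.0
# becomes -6.0 (multiplied by the penalty), so repeated tokens always become less likely.
#
#     proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
#     penalized_logits = proc(input_ids, logits)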


def _get_ngrams(ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The generated ngrams look like
    this: {(40,): [2883], (2883,): [2712], (2712,): [4346]}.

    Args:
        ngram_size (`int`):
            The number of sequential tokens taken as a group, which may only occur once before being banned.
        prev_input_ids (`paddle.Tensor`):
            Generated token ids for the current hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    # One dict per hypothesis, mapping an (ngram_size - 1)-token prefix to the tokens that followed it.
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Determines the banned tokens for the current hypothesis based on previously generated n-grams.

    Args:
        banned_ngrams (`dict`):
            A dictionary containing previously generated n-grams for each hypothesis.
        prev_input_ids (`paddle.Tensor`):
            Generated token ids for the current hypothesis.
        ngram_size (`int`):
            The number of sequential tokens taken as a group, which may only occur once before being banned.
        cur_len (`int`):
            The current length of the token sequences for which the n-grams are being checked.

    Returns:
        List of tokens that are banned.
    """
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int, cur_len: int
):
    """Copied from fairseq for no_repeat_ngram in beam_search."""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(
            generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len
        )
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens


class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
            )
        self.ngram_size = ngram_size

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(
            self.ngram_size, input_ids, num_batch_hypotheses, cur_len
        )
        for i, banned_tokens in enumerate(banned_batch_tokens):
            if len(banned_tokens) == 0:
                continue
            scores[i, banned_tokens] = paddle.finfo(scores.dtype).min
        return scores
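
# Illustrative sketch (placeholder ids): with ngram_size=2 and a hypothesis ending in
# "... 40 2883 ... 40", the bigram table built by _get_ngrams maps (40,) -> [2883], so
# token 2883 is banned (set to the dtype minimum) for the current step.
#
#     proc = NoRepeatNGramLogitsProcessor(ngram_size=2)
#     scores = proc(input_ids, scores)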


class HammingDiversityLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces diverse beam search. Note that this logits
    processor is only effective for `group_beam_search`. See
    `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

    Args:
        diversity_rate (float): This value is subtracted from a beam's score if
            it generates the same token as any beam from another group at a
            particular time step.
        num_beams (int): Number of beams used for group beam search.
        num_beam_groups (int): Number of groups to divide `num_beams` into in order
            to ensure diversity among different groups of beams.
    """

    def __init__(self, diversity_rate: float, num_beams: int, num_beam_groups: int):
        if not isinstance(diversity_rate, float) or (not diversity_rate > 0.0):
            raise ValueError("`diversity_rate` should be a float strictly larger than 0.")
        self._diversity_rate = diversity_rate
        if not isinstance(num_beams, int) or num_beams < 2:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams
        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: paddle.Tensor,
        scores: paddle.Tensor,
        current_tokens: paddle.Tensor,
        beam_group_idx: int,
    ):
        batch_size = current_tokens.shape[0] // self._num_beams
        group_start_idx = beam_group_idx * self._num_sub_beams
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
        group_size = group_end_idx - group_start_idx
        vocab_size = scores.shape[-1]
        if group_start_idx == 0:
            # The first group has no previous groups to diversify against.
            return scores
        for batch_idx in range(batch_size):
            # Tokens already chosen at this step by beams in the preceding groups.
            previous_group_tokens = current_tokens[
                batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
            ]
            token_frequency = paddle.bincount(previous_group_tokens, minlength=vocab_size)
            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
                self._diversity_rate * token_frequency
            )
        return scores
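
# Indexing sketch (assumed setup, not from this file): with num_beams=4 and
# num_beam_groups=2, each group holds 2 sub-beams. For beam_group_idx=1, the tokens
# already picked by group 0 at this step are counted with paddle.bincount and
# diversity_rate * frequency is subtracted from the second group's scores, pushing
# that group toward different tokens.
#
#     proc = HammingDiversityLogitsProcessor(diversity_rate=0.5, num_beams=4, num_beam_groups=2)
#     scores = proc(input_ids, scores, current_tokens=current_tokens, beam_group_idx=1)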


class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces the first generated token to be the selected `forced_bos_token`.

    Args:
        forced_bos_token_id (:obj:`int`):
            The id of the token to be generated as the first token.
    """

    def __init__(self, forced_bos_token_id: int):
        self.forced_bos_token_id = forced_bos_token_id

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        cur_len = input_ids.shape[-1]
        if cur_len == 1:
            scores[:] = paddle.finfo(scores.dtype).min
            scores[:, self.forced_bos_token_id] = 0
        return scores


class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces the last generated token to be the selected `forced_eos_token`.

    Args:
        max_length (int): The maximum length of the sequence to be generated.
        forced_eos_token_id (int): The id of the token to be generated as the last token.
    """

    def __init__(self, max_length: int, forced_eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        self.forced_eos_token_id = forced_eos_token_id

    def __call__(self, input_ids, scores):
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            scores[:] = paddle.finfo(scores.dtype).min
            scores[:, self.forced_eos_token_id] = 0
        return scores


def TopKProcess(probs: paddle.Tensor, top_k: int, min_tokens_to_keep: int):
    top_k = paddle.minimum(
        paddle.maximum(paddle.to_tensor(top_k), paddle.to_tensor(min_tokens_to_keep)),
        paddle.to_tensor(probs.shape[-1]),
    )
    # Remove all tokens with a probability less than the last token of the top-k.
    # bfloat16 inputs are computed in float32 for top-k (to support generation & d2s);
    # the top-k values are cast back to bfloat16 afterwards.
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)
        topk_probs, _ = paddle.topk(probs, k=top_k)
        topk_probs = paddle.cast(topk_probs, paddle.bfloat16)
    else:
        topk_probs, _ = paddle.topk(probs, k=top_k)
    probs = paddle.where(probs >= topk_probs[:, -1:], probs, paddle.full_like(probs, 0.0))
    return probs
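
# Worked sketch (made-up probabilities): with top_k=2 and min_tokens_to_keep=1, a row
# [0.1, 0.4, 0.3, 0.2] keeps only the two largest entries and zeroes the rest.
#
#     probs = paddle.to_tensor([[0.1, 0.4, 0.3, 0.2]])
#     TopKProcess(probs, top_k=2, min_tokens_to_keep=1)
#     # -> [[0.0, 0.4, 0.3, 0.0]]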


def TopPProcess(probs: paddle.Tensor, top_p: float, min_tokens_to_keep: int):
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)
        sorted_indices = paddle.argsort(probs, descending=True)
        sorted_probs = paddle.sort(probs, descending=True)
        sorted_probs = paddle.cast(sorted_probs, paddle.bfloat16)
    else:
        sorted_indices = paddle.argsort(probs, descending=True)
        sorted_probs = paddle.sort(probs, descending=True)
    cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
    # Remove tokens with cumulative probs above the top_p, but keep at
    # least min_tokens_to_keep tokens
    sorted_indices_to_remove = cumulative_probs > top_p
    if min_tokens_to_keep > 1:
        # Set 'min_tokens_to_keep - 1' because the first token is kept
        sorted_indices_to_remove[:, : min_tokens_to_keep - 1] = 0
    # Keep the first token
    sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0
    # Scatter sorted tensors to original indexing
    sorted_indices = (
        sorted_indices
        + paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1) * probs.shape[-1]
    )
    condition = paddle.scatter(
        sorted_indices_to_remove.flatten(),
        sorted_indices.flatten(),
        sorted_indices_to_remove.flatten(),
    )
    condition = paddle.cast(condition, "bool").reshape(probs.shape)
    probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
    return probs
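
# Worked sketch (made-up probabilities): with top_p=0.7 and min_tokens_to_keep=1, a row
# [0.5, 0.3, 0.1, 0.1] has cumulative sums [0.5, 0.8, 0.9, 1.0] after sorting; because
# the removal mask is shifted right by one, the token that first crosses the threshold
# is still kept, so the result is [0.5, 0.3, 0.0, 0.0].
#
#     probs = paddle.to_tensor([[0.5, 0.3, 0.1, 0.1]])
#     TopPProcess(probs, top_p=0.7, min_tokens_to_keep=1)
#     # -> [[0.5, 0.3, 0.0, 0.0]]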


class LogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponentially scaling the output probability distribution).

    Args:
        temperature (`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(
                f"`temperature` has to be a strictly positive float, but is {temperature}"
            )
        self.temperature = temperature

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        scores = scores / self.temperature
        return scores
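
# Worked sketch: temperature simply rescales the logits before softmax. With logits
# [2.0, 1.0], a temperature of 0.5 gives [4.0, 2.0] (sharper distribution), while a
# temperature of 2.0 gives [1.0, 0.5] (flatter distribution).
#
#     warper = TemperatureLogitsWarper(temperature=0.5)
#     scores = warper(input_ids, paddle.to_tensor([[2.0, 1.0]]))   # -> [[4.0, 2.0]]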


class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to get the most out of biasing sequences with more
    than one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`.

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied).

    Examples:
    ```python
    >>> from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-en")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pd")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2-en", add_prefix_space=True)

    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])

    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is inferred in the first usage, which inhibits initializing here)
        self.length_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids, scores):
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias to add
        bias = paddle.zeros_like(scores)

        # 3 - include the bias from length = 1
        if self.length_1_bias is not None:
            bias += self.length_1_bias

        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
        for sequence_ids, sequence_bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
                continue
            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            matching_rows = (
                paddle.equal(
                    input_ids[:, -prefix_length:],
                    paddle.to_tensor(sequence_ids[:-1], dtype=input_ids.dtype),
                )
                .astype(paddle.int64)
                .prod(axis=1)
            )
            bias[:, last_token] += paddle.where(
                matching_rows == 1,
                paddle.to_tensor(sequence_bias),
                paddle.to_tensor(0.0),
            )

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores

    def _prepare_bias_variables(self, scores):
        vocabulary_size = scores.shape[-1]

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in self.sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic.
        self.length_1_bias = paddle.zeros((vocabulary_size,))
        for sequence_ids, bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(
                f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}."
            )
        if any(
            not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in sequence_ids
            )
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(
                f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}."
            )


class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Examples:
    ```python
    >>> from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-en")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
    >>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pd")

    >>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a mess.

    >>> # Now let's take the bad words out. Please note that the tokenizer is initialized differently
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2-en", add_prefix_space=True)

    >>> def get_tokens_as_list(word_list):
    ...     "Converts a sequence of words into a list of tokens"
    ...     tokens_list = []
    ...     for word in word_list:
    ...         tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
    ...         tokens_list.append(tokenized_word)
    ...     return tokens_list

    >>> bad_words_ids = get_tokens_as_list(word_list=["mess"])
    >>> output_ids = model.generate(
    ...     inputs["input_ids"], max_new_tokens=5, bad_words_ids=bad_words_ids, pad_token_id=tokenizer.eos_token_id
    ... )
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a surprise.

    >>> # The same constraint can also be passed as an explicit logits processor
    >>> from paddlenlp.transformers.generation import NoBadWordsLogitsProcessor, LogitsProcessorList
    >>> logits_processors = LogitsProcessorList([NoBadWordsLogitsProcessor([[5, 6]], eos_token_id=tokenizer.eos_token_id)])
    >>> output_ids = model.generate(
    ...     inputs["input_ids"], max_new_tokens=5, logits_processors=logits_processors, pad_token_id=tokenizer.eos_token_id
    ... )
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a surprise.
    ```
    """

    def __init__(
        self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]
    ):
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Filter EOS token from bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        bad_words_ids = list(
            filter(
                lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id),
                bad_words_ids,
            )
        )

        # Forbidding a sequence is equivalent to setting its bias to -inf
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        super().__init__(sequence_bias=sequence_bias)

    def _validate_arguments(self):
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(
                f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(
                f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in bad_word_ids
            )
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
    generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.

    Args:
        prefix_allowed_tokens_fn (`Callable[[int, paddle.Tensor], List[int]]`):
            This function constrains the beam search to allowed tokens only at each step. This function takes 2
            arguments: `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for
            the next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
            `batch_id`.
    """

    def __init__(
        self,
        prefix_allowed_tokens_fn: Callable[[int, paddle.Tensor], List[int]],
        num_beams: int,
    ):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor) -> paddle.Tensor:
        # Start from an all -inf mask and open up only the allowed token ids per (batch, beam).
        mask = paddle.full_like(scores, paddle.finfo(scores.dtype).min)
        for batch_id, beam_sent in enumerate(
            input_ids.reshape([-1, self._num_beams, input_ids.shape[-1]])
        ):
            for beam_id, sent in enumerate(beam_sent):
                mask[
                    batch_id * self._num_beams + beam_id,
                    self._prefix_allowed_tokens_fn(batch_id, sent),
                ] = 0
        return scores + mask
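
# Illustrative sketch (placeholder callback, not from this file): a
# prefix_allowed_tokens_fn receives the batch index and the already generated ids of
# one beam, and returns the token ids allowed at the next step; every other token is
# masked to the dtype minimum.
#
#     def allow_only_even_ids(batch_id, sent):
#         return [i for i in range(0, 100, 2)]
#
#     proc = PrefixConstrainedLogitsProcessor(allow_only_even_ids, num_beams=4)
#     scores = proc(input_ids, scores)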