- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import inspect
- from typing import List, Optional, Union
- import paddle
- import paddle.distributed as dist
- import paddle.nn as nn
- import paddle.nn.functional as F
- from paddle import Tensor
- from paddle.common_ops_import import convert_dtype
- from paddle.utils import map_structure
- from ......utils import logging
- from ..transformers.model_outputs import ModelOutput
- from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
- from .logits_process import (
- ForcedBOSTokenLogitsProcessor,
- ForcedEOSTokenLogitsProcessor,
- HammingDiversityLogitsProcessor,
- LogitsProcessor,
- LogitsProcessorList,
- MinLengthLogitsProcessor,
- NoRepeatNGramLogitsProcessor,
- RepetitionPenaltyLogitsProcessor,
- TopKProcess,
- TopPProcess,
- )
- from .stopping_criteria import (
- StoppingCriteria,
- StoppingCriteriaList,
- validate_stopping_criteria,
- )
- __all__ = [
- "GenerationMixin",
- "BeamSearchScorer",
- "BeamHypotheses",
- "LogitsProcessorList",
- "LogitsProcessor",
- "MinLengthLogitsProcessor",
- "RepetitionPenaltyLogitsProcessor",
- "TopKProcess",
- "TopPProcess",
- "get_unfinished_flag",
- ]
- def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
- """get scale value by dtype
- Args:
- dtype (str): the string dtype value
- Returns:
- float: the scale value
- """
- if dtype is None:
- dtype = paddle.get_default_dtype()
- dtype = convert_dtype(dtype)
- scale_value = 1e6
- # TODO(wj-Mcaf): support int8, int4 dtypes later
- if dtype == "float16":
- scale_value = 1e4
- if return_positive:
- return scale_value
- return -1 * scale_value
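- # Illustrative examples (assumed call values): get_scale_by_dtype() returns 1e6
- # for the default float32 dtype and 1e4 for "float16"; passing
- # return_positive=False flips the sign, e.g.
- # get_scale_by_dtype("float16", return_positive=False) -> -1e4.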
- def get_unfinished_flag(
- input_ids: Tensor,
- unfinished_flag: Tensor,
- eos_token_id: Union[int, List[int], List[List[int]]],
- ) -> Tensor:
- """get unfinished flag for generation step
- Args:
- input_ids (Tensor): the input_ids
- eos_token_id (Union[int, list[int], list[list[int]]]): the end os sentence flag, which can be:
- * single token id, eg: 10
- * multiple token ids to stop generation, eg: [10, 10]
- * some more tokens to stop generations, eg: [[10], [20, 20], [30, 30, 30]]
- Returns:
- Tensor: the unfinished flag tensor
- """
- if isinstance(eos_token_id, int):
- unfinished_flag = paddle.logical_and(
- unfinished_flag, input_ids[:, -1:] != eos_token_id
- )
- else:
- batch_unfinish_flag = None
- for batch_eos_token_id in eos_token_id:
- if batch_unfinish_flag is None:
- batch_unfinish_flag = ~get_unfinished_flag(
- input_ids, unfinished_flag, batch_eos_token_id
- )
- else:
- batch_unfinish_flag = paddle.logical_or(
- batch_unfinish_flag,
- ~get_unfinished_flag(
- input_ids, unfinished_flag, batch_eos_token_id
- ),
- )
- unfinished_flag = ~batch_unfinish_flag
- return unfinished_flag
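- # A minimal usage sketch (token ids here are assumed, illustrative values):
- #   flag = paddle.full([batch_size, 1], True, dtype="bool")
- #   flag = get_unfinished_flag(input_ids, flag, 10)            # single eos id
- #   flag = get_unfinished_flag(input_ids, flag, [10, 11])      # a list of eos ids
- #   flag = get_unfinished_flag(input_ids, flag, [[10], [11]])  # nested lists, handled recursively
- # The result stays a [batch_size, 1] bool tensor; an entry becomes False once
- # the corresponding sequence is considered finished.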
- class BeamHypotheses:
- def __init__(self, num_beams, length_penalty, early_stopping):
- """
- Initialize n-best list of hypotheses.
- """
- self.length_penalty = length_penalty
- self.early_stopping = early_stopping
- self.num_beams = num_beams
- self.beams = []
- self.worst_score = get_scale_by_dtype()
- def __len__(self):
- """
- Number of hypotheses in the list.
- """
- return len(self.beams)
- def add(self, hyp, sum_logprobs, origin_len=0):
- """
- Add a new hypothesis to the list.
- """
- score = sum_logprobs / (
- ((hyp.shape[-1] - origin_len + 5) / 6) ** self.length_penalty
- )
- if len(self) < self.num_beams or score > self.worst_score:
- self.beams.append((score, hyp))
- if len(self) > self.num_beams:
- sorted_next_scores = sorted(
- [(s, idx) for idx, (s, _) in enumerate(self.beams)]
- )
- del self.beams[sorted_next_scores[0][1]]
- self.worst_score = sorted_next_scores[1][0]
- else:
- self.worst_score = min(score, self.worst_score)
- def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
- """
- If there are enough hypotheses and none of the hypotheses being
- generated can become better than the worst one in the heap, then we
- are done with this sentence.
- """
- if len(self) < self.num_beams:
- return False
- elif self.early_stopping:
- return True
- else:
- cur_score = (
- best_sum_logprobs
- / ((cur_len - origin_len + 5) / 6) ** self.length_penalty
- )
- ret = self.worst_score >= cur_score
- return ret
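- # Illustrative note: BeamHypotheses.add ranks candidates with a GNMT-style
- # length penalty, score = sum_logprobs / (((hyp_len - origin_len + 5) / 6) ** length_penalty).
- # For example, assuming sum_logprobs = -6.0, hyp_len - origin_len = 7 and
- # length_penalty = 1.0, the stored score is -6.0 / (12 / 6) = -3.0.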
- class BeamSearchScorer(object):
- """
- Implements standard beam search decoding.
- """
- def __init__(
- self,
- batch_size,
- max_length,
- num_beams,
- length_penalty=1.0,
- do_early_stopping=False,
- num_beam_hyps_to_keep=1,
- num_beam_groups=1,
- ):
- self.max_length = max_length
- self.num_beams = num_beams
- self.length_penalty = length_penalty
- self.do_early_stopping = do_early_stopping
- self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
- self.num_beam_groups = num_beam_groups
- self.group_size = self.num_beams // self.num_beam_groups
- self._is_init = False
- self._beam_hyps = [
- BeamHypotheses(
- num_beams=self.num_beams,
- length_penalty=self.length_penalty,
- early_stopping=self.do_early_stopping,
- )
- for _ in range(batch_size)
- ]
- self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")
- if not isinstance(num_beams, int) or num_beams <= 1:
- raise ValueError(
- "`num_beams` has to be an integer strictly greater than 1, but "
- "received {}. For `num_beams` == 1, one should make use of "
- "`greedy_search` instead.".format(num_beams)
- )
- if (
- not isinstance(num_beam_groups, int)
- or (num_beam_groups > num_beams)
- or (num_beams % num_beam_groups != 0)
- ):
- raise ValueError(
- "`num_beam_groups` has to be an integer smaller or equal than "
- "`num_beams` and `num_beams` has to be divisible by "
- "`num_beam_groups`, but received num_beam_groups={}, num_beams="
- "{}.".format(num_beam_groups, num_beams)
- )
- @property
- def is_done(self):
- return paddle.min(self._done) == 1
- def process(
- self,
- input_ids,
- next_scores,
- next_tokens,
- next_indices,
- origin_len=0,
- pad_token_id=None,
- eos_token_id=None,
- ):
- cur_len = input_ids.shape[-1]
- batch_size = len(self._beam_hyps)
- assert batch_size == (input_ids.shape[0] // self.group_size)
- next_beam_scores = paddle.zeros(
- [batch_size, self.group_size], dtype=next_scores.dtype
- )
- next_beam_tokens = paddle.zeros(
- [batch_size, self.group_size], dtype=next_tokens.dtype
- )
- next_beam_indices = paddle.zeros(
- [batch_size, self.group_size], dtype=next_indices.dtype
- )
- for batch_idx, beam_hyp in enumerate(self._beam_hyps):
- if self._done[batch_idx] == 1:
- assert (
- len(beam_hyp) >= self.num_beams
- ), "Batch can only be done if at least {} beams have been generated".format(
- self.num_beams
- )
- assert (
- eos_token_id is not None and pad_token_id is not None
- ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
- # pad the batch
- next_beam_scores[batch_idx, :] = 0
- next_beam_tokens[batch_idx, :] = pad_token_id
- next_beam_indices[batch_idx, :] = 0
- continue
- # next tokens for this sentence
- beam_idx = 0
- for beam_token_rank, (next_token, next_score, next_index) in enumerate(
- zip(
- next_tokens[batch_idx],
- next_scores[batch_idx],
- next_indices[batch_idx],
- )
- ):
- batch_beam_idx = batch_idx * self.group_size + next_index
- # add to generated hypotheses if end of sentence
- if (eos_token_id is not None) and (next_token.item() == eos_token_id):
- # If beam_token does not belong to top num_beams tokens,
- # it should not be added
- is_beam_token_worse_than_top_num_beams = (
- beam_token_rank >= self.group_size
- )
- if is_beam_token_worse_than_top_num_beams:
- continue
- beam_hyp.add(
- input_ids[batch_beam_idx.item()].clone(),
- next_score.item(),
- origin_len,
- )
- else:
- # add next predicted token since it is not eos_token
- next_beam_scores[batch_idx, beam_idx] = next_score
- next_beam_tokens[batch_idx, beam_idx] = next_token.item()
- next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
- beam_idx += 1
- # once the beam for next step is full, don't add more tokens to it.
- if beam_idx == self.group_size:
- break
- if beam_idx < self.group_size:
- raise ValueError(
- "At most {} tokens in `next_tokens[batch_idx]` can be equal "
- "to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
- "are corrected.".format(self.group_size, eos_token_id)
- )
- # Check if we are done so that we can save a pad step if all(done)
- if beam_hyp.is_done(
- next_scores[batch_idx].max().item(), cur_len, origin_len
- ):
- self._done[batch_idx] = 1
- return {
- "next_beam_scores": next_beam_scores.reshape([-1]),
- "next_beam_tokens": next_beam_tokens.reshape([-1]),
- "next_beam_indices": next_beam_indices.reshape([-1]),
- }
- def finalize(
- self,
- input_ids,
- final_beam_scores,
- final_beam_tokens,
- final_beam_indices,
- origin_len=0,
- pad_token_id=None,
- eos_token_id=None,
- ):
- batch_size = len(self._beam_hyps)
- # finalize all open beam hypotheses and add to generated hypotheses
- for batch_idx, beam_hyp in enumerate(self._beam_hyps):
- if self._done[batch_idx] == 1:
- continue
- # all open beam hypotheses are added to the beam hypothesis
- # beam hypothesis class automatically keeps the best beams
- for beam_id in range(self.num_beams):
- batch_beam_idx = batch_idx * self.num_beams + beam_id
- final_score = final_beam_scores[batch_beam_idx].item()
- final_tokens = input_ids[batch_beam_idx]
- beam_hyp.add(final_tokens, final_score, origin_len=origin_len)
- # select the best hypotheses
- sent_lengths = paddle.zeros(
- [batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
- )
- best = []
- # retrieve best hypotheses
- for i, beam_hyp in enumerate(self._beam_hyps):
- sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
- for j in range(self.num_beam_hyps_to_keep):
- best_score, best_hyp = sorted_hyps.pop()
- sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
- best.append([best_hyp, best_score])
- # prepare for adding eos
- sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
- decoded = paddle.zeros(
- [batch_size * self.num_beam_hyps_to_keep, sent_max_len],
- dtype=input_ids.dtype,
- )
- # shorter batches are padded if needed
- if sent_lengths.min().item() != sent_lengths.max().item():
- assert pad_token_id is not None, "`pad_token_id` has to be defined"
- decoded[:, :] = pad_token_id
- decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])
- # fill with hypotheses and eos_token_id if the latter fits in
- for i, (hypo, score) in enumerate(best):
- decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
- decoded_score[i] = score
- if sent_lengths[i] < self.max_length:
- decoded[i, sent_lengths[i].item()] = eos_token_id
- return decoded, decoded_score
- class GenerationMixin(object):
- r"""
- This class implements the interface for generation task.
- It's used as the base class of `paddlenlp.transformers.PretrainedModel
- <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
- """
- # enable `to_static` method for CausalLM Model
- enable_to_static_method = False
- @staticmethod
- def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
- batch_size = 1
- if bos_token_id is None:
- raise ValueError(
- "`bos_token_id` should be defined when no " "`input_ids` are provided."
- )
- if encoder_output is not None:
- batch_size = encoder_output.shape[0]
- return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
- @staticmethod
- def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
- is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
- input_ids == pad_token_id
- ).item()
- is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
- (eos_token_id is not None) and (pad_token_id != eos_token_id)
- )
- if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
- attention_mask = (input_ids == pad_token_id).astype(
- paddle.get_default_dtype()
- ) * get_scale_by_dtype(return_positive=False)
- else:
- attention_mask = paddle.zeros_like(
- input_ids, dtype=paddle.get_default_dtype()
- )
- return paddle.unsqueeze(attention_mask, axis=[1, 2])
- @staticmethod
- def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
- is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
- input_ids == pad_token_id
- ).item()
- is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
- (eos_token_id is not None) and (pad_token_id != eos_token_id)
- )
- if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
- seq_len = paddle.sum(input_ids != pad_token_id, axis=1).unsqueeze(-1)
- else:
- seq_len = paddle.full(
- (input_ids.shape[0], 1), input_ids.shape[1], dtype="int64"
- )
- return seq_len
- def get_logits_processor(
- self,
- min_length=None,
- max_length=None,
- eos_token_id=None,
- forced_bos_token_id=None,
- forced_eos_token_id=None,
- num_beams=1,
- num_beam_groups=1,
- diversity_rate=0.0,
- repetition_penalty=None,
- no_repeat_ngram_size=None,
- logits_processors=None,
- ):
- processors = LogitsProcessorList()
- if min_length is not None and eos_token_id is not None and min_length > -1:
- processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
- if num_beam_groups > 1 and diversity_rate > 0.0:
- processors.append(
- HammingDiversityLogitsProcessor(
- diversity_rate=diversity_rate,
- num_beams=num_beams,
- num_beam_groups=num_beam_groups,
- )
- )
- if repetition_penalty is not None and repetition_penalty != 1.0:
- processors.append(
- RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
- )
- if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
- processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
- if forced_bos_token_id is not None:
- processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
- if forced_eos_token_id is not None:
- processors.append(
- ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
- )
- # TODO
- # Add more pre_processing for distribution
- if logits_processors is not None:
- custom_processors = LogitsProcessorList()
- custom_processors_type = [type(lp) for lp in logits_processors]
- for processor in processors:
- if type(processor) not in custom_processors_type:
- custom_processors.append(processor)
- custom_processors.extend(logits_processors)
- return custom_processors
- else:
- return processors
- @staticmethod
- def expand_inputs_for_generation(
- input_ids, expand_size, attention_mask=None, **model_kwargs
- ):
- index = paddle.tile(
- paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1),
- [1, expand_size],
- ).reshape([-1])
- input_ids = paddle.gather(input_ids, index)
- if attention_mask is not None:
- model_kwargs["attention_mask"] = paddle.gather(attention_mask, index)
- if (
- "token_type_ids" in model_kwargs
- and model_kwargs["token_type_ids"] is not None
- ):
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index)
- if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
- position_ids = model_kwargs["position_ids"]
- model_kwargs["position_ids"] = paddle.gather(position_ids, index)
- if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
- seq_len = model_kwargs["seq_len"]
- model_kwargs["seq_len"] = paddle.gather(seq_len, index)
- if (
- "encoder_output" in model_kwargs
- and model_kwargs["encoder_output"] is not None
- ):
- encoder_output = model_kwargs["encoder_output"]
- model_kwargs["encoder_output"] = paddle.gather(encoder_output, index)
- if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
- role_ids = model_kwargs["role_ids"]
- model_kwargs["role_ids"] = paddle.gather(role_ids, index)
- return input_ids, model_kwargs
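- # A small sketch of the gather index built above (assumed shapes): with
- # input_ids.shape[0] == 2 and expand_size == 3, `index` is [0, 0, 0, 1, 1, 1],
- # so every tensor gathered with it repeats each batch row expand_size times
- # (used below for num_return_sequences in sampling and num_beams in beam search).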
- @staticmethod
- def update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=False
- ):
- # Update the model inputs during generation.
- # Note that if `token_type_ids` and `attention_mask` in `model_kwargs`
- # contain pad values, the result vectors updated by this method
- # may be different from expected. In this case, you need to rewrite the
- # method.
- # update cache
- if (
- isinstance(outputs, tuple)
- and len(outputs) > 1
- and not isinstance(outputs[1], paddle.Tensor)
- ):
- model_kwargs["cache"] = outputs[1]
- model_kwargs["past_key_values"] = outputs[1]
- if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
- model_kwargs["cache"] = outputs.past_key_values
- model_kwargs["past_key_values"] = outputs.past_key_values
- # update token_type_ids with last value
- if (
- "token_type_ids" in model_kwargs
- and model_kwargs["token_type_ids"] is not None
- ):
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = paddle.concat(
- [token_type_ids, token_type_ids[:, -1:]], axis=-1
- )
- # update position_ids
- if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
- position_ids = model_kwargs["position_ids"]
- model_kwargs["position_ids"] = paddle.concat(
- [position_ids, position_ids[..., -1:] + 1], axis=-1
- )
- # update attention_mask
- if not is_encoder_decoder and "attention_mask" in model_kwargs:
- attention_mask = model_kwargs["attention_mask"]
- # nn.Pad2D doesn't support the data type `bool`
- if convert_dtype(attention_mask.dtype) == "bool":
- attention_mask = paddle.cast(attention_mask, "int64")
- if len(attention_mask.shape) == 4:
- cur_device = paddle.get_device()
- if cur_device.split(":")[0] == "npu":
- attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
- attention_mask
- )
- attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
- else:
- attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
- attention_mask
- )
- attention_mask = nn.Pad2D(
- [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
- )(attention_mask)
- dtype = convert_dtype(attention_mask.dtype)
- if "int" in dtype:
- attention_mask[:, :, -1, -1] = 1
- elif "float" in dtype:
- attention_mask[:, :, -1, -1] = 0.0
- else:
- raise ValueError(
- "The data type of input `attention_mask` must "
- "be bool, int or float"
- )
- else:
- attention_mask = paddle.concat(
- [
- attention_mask,
- paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
- ],
- axis=-1,
- )
- model_kwargs["attention_mask"] = attention_mask
- # update role_ids
- if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
- role_ids = model_kwargs["role_ids"]
- model_kwargs["role_ids"] = paddle.concat(
- [role_ids, role_ids[:, -1:]], axis=-1
- )
- return model_kwargs
- @staticmethod
- def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
- # update scores
- unfinished_scores = (
- scores * paddle.to_tensor(length, dtype=scores.dtype) + next_scores
- ) / (paddle.to_tensor(length, dtype=scores.dtype) + 1)
- scores = paddle.where(unfinished_flag, unfinished_scores, scores)
- return scores
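- # Illustrative note: the update above keeps a running mean of per-token scores
- # for unfinished sequences, scores_new = (scores * length + next_scores) / (length + 1).
- # For example, assuming scores = -2.0, length = 3 and next_scores = -6.0, the
- # updated value is (-2.0 * 3 + -6.0) / 4 = -3.0; finished sequences keep their score.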
- def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
- if "encoder_output" not in model_kwargs:
- # retrieve encoder hidden states
- encoder = self.get_encoder()
- encoder_kwargs = {
- argument: value
- for argument, value in model_kwargs.items()
- if not (
- argument.startswith("decoder_")
- or argument.startswith("cross_attn")
- or argument == "use_cache"
- )
- }
- # Prefer inputs_embeds over input_ids if it exists
- if "inputs_embeds" in encoder_kwargs:
- model_kwargs["encoder_output"] = encoder(**encoder_kwargs)
- else:
- model_kwargs["encoder_output"] = encoder(
- input_ids=input_ids, **encoder_kwargs
- )
- return model_kwargs
- def prepare_decoder_input_ids_for_generation(
- self, input_ids, decoder_start_token_id=None, bos_token_id=None
- ):
- decoder_start_token_id = (
- decoder_start_token_id
- if decoder_start_token_id is not None
- else self.config.decoder_start_token_id
- )
- decoder_start_token_id = (
- decoder_start_token_id
- if decoder_start_token_id is not None
- else bos_token_id
- )
- decoder_input_ids = (
- paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
- )
- return decoder_input_ids
- def get_decoder_start_token_id(
- self, decoder_start_token_id=None, bos_token_id=None
- ):
- decoder_start_token_id = (
- decoder_start_token_id
- if decoder_start_token_id is not None
- else self.config.decoder_start_token_id
- )
- bos_token_id = (
- bos_token_id if bos_token_id is not None else self.config.bos_token_id
- )
- if decoder_start_token_id is not None:
- return decoder_start_token_id
- elif self.config.decoder_start_token_id is not None:
- return self.config.decoder_start_token_id
- elif bos_token_id is not None:
- return bos_token_id
- elif self.config.bos_token_id is not None:
- return self.config.bos_token_id
- raise ValueError(
- "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
- )
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
- # Implement in subclasses for custom behavior to prepare inputs in the
- # generate method.
- return {"input_ids": input_ids}
- def adjust_logits_during_generation(self, logits):
- # Implement in subclasses for custom behavior to adjust the logits in
- # the generate method.
- return logits
- def prepare_fast_entry(self, kwargs):
- return False
- def _convert_to_fast(self, kwargs):
- # try general convert
- pass
- def _build_fast(self, kwargs):
- self._fast_entry = False
- if kwargs["num_beam_groups"] != 1:
- # group_beam_search is not supported yet in the fast version
- raise AttributeError(
- "'num_beam_groups != 1' is not supported yet in the fast version"
- )
- if (
- paddle.get_default_dtype() == "float16"
- and kwargs["use_fp16_decoding"] is False
- ):
- logging.info(
- "Since the default dtype is float16, float16 would be used "
- "though 'use_fp16_decoding=False'."
- )
- kwargs["use_fp16_decoding"] = True
- self.prepare_fast_entry(kwargs)
- def set_pad_token_id(self, pad_token_id, eos_token_id):
- if pad_token_id is None and eos_token_id is not None:
- logging.warning(
- "Setting `pad_token_id` to `eos_token_id`:{} for "
- "open-end generation.".format(eos_token_id)
- )
- if isinstance(eos_token_id, list):
- pad_token_id = eos_token_id[0]
- else:
- pad_token_id = eos_token_id
- return pad_token_id
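- # A minimal sketch of the fallback above (token ids are assumed values):
- #   model.set_pad_token_id(None, 2)       -> 2, with a warning
- #   model.set_pad_token_id(None, [2, 3])  -> 2 (the first eos id), with a warning
- #   model.set_pad_token_id(0, 2)          -> 0 (unchanged)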
- @paddle.no_grad()
- def generate(
- self,
- input_ids: paddle.Tensor = None,
- generation_config: GenerationConfig = None,
- stopping_criteria: StoppingCriteria = None,
- streamer=None,
- synced_gpus: Optional[bool] = None,
- **kwargs,
- ):
- r"""
- The interface for generation task. This method can generate sequences
- by using decoding strategy. Currently, there are three decoding
- strategies supported: "greedy_search", "sampling" and "beam_search".
- Args:
- input_ids (Tensor, optional): The input sequence ids for the
- generation. It is a Tensor with shape [batch_size, sequence_length].
- The data type should be int32 or int64. Default to None, which
- we will initialize it as a Tensor with shape [1, 1], filled
- with the value `bos_token_id`.
- generation_config (`~generation.GenerationConfig`, *optional*):
- The generation configuration to be used as base parametrization for the generation call. `**kwargs`
- passed to generate matching the attributes of `generation_config` will override them. If
- `generation_config` is not provided, the default will be used, which has the following loading
- priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
- configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
- default values, whose documentation should be checked to parameterize generation.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- Custom stopping criteria that complement the default stopping criteria built from arguments and a
- generation config. If a stopping criteria is passed that is already created with the arguments or a
- generation config an error is thrown. This feature is intended for advanced users.
- streamer (`~streamer.BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
- synced_gpus (`bool`, *optional*):
- Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
- `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
- generating before other GPUs. Otherwise it'll be set to `False`.
- kwargs (dict): It can be used to specify additional kwargs
- passed to the model.
- Returns:
- tuple[Tensor]: It is a tuple containing two elements: ids and scores.
- Each element is a Tensor.
- With the fields:
- - ids (Tensor):
- The ids of the generated sequences. It is a Tensor with shape
- [batch_size * num_return_sequences, sequence_length]. The data
- type is the same as the input `input_ids`.
- - scores (Tensor):
- The scores of the generated sequences. It is a Tensor with shape
- [batch_size * num_return_sequences, 1]. The data type is float32
- or float64, which is the same as the parameters in the model.
- Example:
- .. code-block::
- import paddle
- from paddlenlp.transformers import (
- UnifiedTransformerLMHeadModel,
- UnifiedTransformerTokenizer
- )
- paddle.seed(2)
- # Initialize the model and tokenizer
- model_name_or_path = 'unified_transformer-12L-cn-luge'
- model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
- tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
- # Prepare the model inputs.
- history = "早上好,今天空气质量不错。"
- inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
- add_start_token_as_response=True, return_tensors=True)
- .. code-block::
- # Generate the sequence by using "greedy_search" strategy
- ids, scores = model.generate(
- **inputs,
- decode_strategy="greedy_search")
- print(ids.shape, scores.shape)
- # [1, 3] [1, 1]
- sequence_ids = ids.cpu().numpy().tolist()[0]
- sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
- response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
- print(response)
- # 是的
- .. code-block::
- # Generate 2 sequences by using "sampling" strategy (top_k=5)
- generation_config = GenerationConfig(
- decode_strategy="sampling",
- top_k=5,
- num_return_sequences=2
- )
- ids, scores = model.generate(
- **inputs,
- generation_config=generation_config,
- )
- print(ids.shape, scores.shape)
- # [2, 7] [2, 1]
- response = []
- for sequence_ids in ids.cpu().numpy().tolist():
- sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
- text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
- response.append(text)
- print(response)
- # ['天气好,心情也好', '你也是']
- .. code-block::
- # Generate 2 sequences by using "beam_search" strategy (num_beams=5)
- generation_config = GenerationConfig(
- decode_strategy="beam_search",
- num_beams=5,
- num_return_sequences=2
- )
- ids, scores = model.generate(
- **inputs,
- generation_config=generation_config,
- )
- print(ids.shape, scores.shape)
- # [2, 3] [2, 1]
- response = []
- for sequence_ids in ids.cpu().numpy().tolist():
- sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
- text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
- response.append(text)
- print(response)
- # ['是的', '嗯嗯']
- """
- if generation_config is None:
- if self.generation_config is None or (
- self.generation_config._from_model_config
- and self.config._has_non_default_generation_parameters()
- ):
- new_generation_config = GenerationConfig.from_model_config(self.config)
- if new_generation_config != self.generation_config:
- logging.warning(
- "model.generation_config is in conflict with model.config, "
- "model.config is used."
- )
- self.generation_config = new_generation_config
- generation_config = self.generation_config
- # do not update model.generation_config in place
- generation_config = copy.deepcopy(generation_config)
- model_kwargs = generation_config.update(**kwargs)
- assert generation_config.decode_strategy in [
- "greedy_search",
- "sampling",
- "beam_search",
- ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
- generation_config.decode_strategy
- )
- if getattr(self, "deprecated_warnings", None) is None:
- self.deprecated_warnings = {}
- use_fast = False
- if "use_faster" in model_kwargs:
- raise ValueError("`use_faster` is deprecated now.")
- if "use_fast" in model_kwargs:
- raise ValueError("`use_fast` is deprecated now.")
- bos_token_id = (
- generation_config.bos_token_id
- if generation_config.bos_token_id is not None
- else self.config.bos_token_id
- )
- eos_token_id = (
- generation_config.eos_token_id
- if generation_config.eos_token_id is not None
- else self.config.eos_token_id
- )
- pad_token_id = (
- generation_config.pad_token_id
- if generation_config.pad_token_id is not None
- else self.config.pad_token_id
- )
- forced_bos_token_id = (
- generation_config.forced_bos_token_id
- if generation_config.forced_bos_token_id is not None
- else self.config.forced_bos_token_id
- )
- forced_eos_token_id = (
- generation_config.forced_eos_token_id
- if generation_config.forced_eos_token_id is not None
- else self.config.forced_eos_token_id
- )
- decoder_start_token_id = (
- generation_config.decoder_start_token_id
- if generation_config.decoder_start_token_id is not None
- else self.config.decoder_start_token_id
- )
- no_repeat_ngram_size = (
- generation_config.no_repeat_ngram_size
- if generation_config.no_repeat_ngram_size is not None
- else self.config.no_repeat_ngram_size
- )
- if getattr(self, "_fast_entry", None) is not False and use_fast:
- fg_args = locals()
- fg_args.pop("self")
- fg_args.pop("__class__", None)
- model_kwargs = fg_args.pop("model_kwargs")
- fg_args.update(model_kwargs)
- try:
- if getattr(self, "_fast_entry", None) is None:
- self._build_fast(fg_args)
- if self._fast_entry:
- output = self._fast_entry(**fg_args)
- if isinstance(output, tuple):
- output_ids, dummy_score = output
- else:
- output_ids = output
- # make the result consistent with the fast result
- dummy_score = None
- if generation_config.decode_strategy == "beam_search":
- output_ids = output_ids.transpose([1, 2, 0])
- output_ids = output_ids[
- :, : generation_config.num_return_sequences, :
- ].reshape([-1, output_ids.shape[-1]])
- if dummy_score is not None:
- dummy_score = dummy_score[
- :, : generation_config.num_return_sequences
- ].flatten()
- else:
- output_ids = output_ids.transpose([1, 0])
- return output_ids, dummy_score
- except Exception as e:
- fg_args["model_kwargs"] = model_kwargs
- # TODO
- # Prevent self._convert_to_fast from throwing an exception
- self._convert_to_fast(fg_args)
- logging.warning(e)
- logging.warning(
- "FastGeneration is not available, "
- "and the original version would be used instead."
- )
- # input_ids in model_kwargs is supported
- if "input_ids" in model_kwargs:
- _input_ids = model_kwargs.pop("input_ids")
- if input_ids is None:
- input_ids = _input_ids
- # params check
- if input_ids is None and "inputs_embeds" not in model_kwargs:
- # Init `input_ids` with bos_token_id
- input_ids = self.prepare_input_ids_for_generation(bos_token_id)
- elif "inputs_embeds" in model_kwargs:
- # Add input embeds support
- input_ids = self.prepare_input_ids_for_generation(
- bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
- )
- if model_kwargs.get("attention_mask", None) is None:
- # TODO
- # Init `attention_mask` depending on `pad_token_id`
- model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
- input_ids, pad_token_id, eos_token_id
- )
- self.is_encoder_decoder = self.config.is_encoder_decoder
- if self.is_encoder_decoder:
- model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
- input_ids, model_kwargs
- )
- # set input_ids as decoder_input_ids
- if "decoder_input_ids" in model_kwargs:
- input_ids = model_kwargs.pop("decoder_input_ids")
- else:
- input_ids = self.prepare_decoder_input_ids_for_generation(
- input_ids, decoder_start_token_id, bos_token_id
- )
- # streamer
- if streamer is not None:
- # streamer does not support the beam_search strategy
- if (
- generation_config.decode_strategy == "beam_search"
- or generation_config.num_beams > 1
- ):
- raise ValueError(
- "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
- )
- pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
- if (
- generation_config.max_length != 0
- and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
- ):
- logging.warning(
- "`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
- )
- generation_config.max_new_tokens = generation_config.max_length
- if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
- logging.warning(
- "`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
- )
- generation_config.min_new_tokens = generation_config.min_length
- max_length = generation_config.max_new_tokens
- min_length = generation_config.min_new_tokens
- input_len = input_ids.shape[-1]
- min_len = input_len + min_length
- max_len = input_len + max_length
- logits_processors = self.get_logits_processor(
- min_length=min_len if min_length > 0 else None,
- max_length=max_len,
- eos_token_id=eos_token_id,
- forced_bos_token_id=forced_bos_token_id,
- forced_eos_token_id=forced_eos_token_id,
- num_beams=generation_config.num_beams,
- num_beam_groups=generation_config.num_beam_groups,
- diversity_rate=generation_config.diversity_rate,
- repetition_penalty=generation_config.repetition_penalty,
- no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
- logits_processors=(
- model_kwargs["logits_processors"]
- if "logits_processors" in model_kwargs
- and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
- else None
- ),
- )
- if "logits_processors" in model_kwargs:
- model_kwargs.pop("logits_processors")
- model_kwargs["use_cache"] = generation_config.use_cache
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if generation_config.decode_strategy == "greedy_search":
- if generation_config.num_return_sequences > 1:
- raise ValueError(
- "`num_return_sequences` has to be 1, but is {} "
- "when doing greedy search.".format(
- generation_config.num_return_sequences
- )
- )
- return self.greedy_search(
- input_ids,
- logits_processors,
- max_len,
- pad_token_id,
- eos_token_id,
- stopping_criteria=stopping_criteria,
- streamer=streamer,
- fast_ptq_sampling=generation_config.fast_ptq_sampling,
- trunc_input=generation_config.trunc_input,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
- elif generation_config.decode_strategy == "sampling":
- if generation_config.num_return_sequences > 1:
- input_ids, model_kwargs = self.expand_inputs_for_generation(
- input_ids,
- expand_size=generation_config.num_return_sequences,
- **model_kwargs,
- )
- return self.sample(
- input_ids,
- logits_processors,
- max_len,
- pad_token_id,
- eos_token_id,
- generation_config.top_k,
- generation_config.top_p,
- generation_config.temperature,
- stopping_criteria=stopping_criteria,
- streamer=streamer,
- fast_ptq_sampling=generation_config.fast_ptq_sampling,
- trunc_input=generation_config.trunc_input,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
- elif generation_config.decode_strategy == "beam_search":
- batch_size = input_ids.shape[0]
- if generation_config.num_return_sequences > generation_config.num_beams:
- raise ValueError(
- "`num_return_sequences` has to be smaller or equal to "
- "`num_beams`. But received `num_return_sequences` is {}, "
- "`num_beams` is {}".format(
- generation_config.num_return_sequences,
- generation_config.num_beams,
- )
- )
- if generation_config.num_beams <= 1:
- raise ValueError(
- "`num_beams` has to be bigger than 1. But received "
- "`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
- "should be 'greedy_search'".format(generation_config.num_beams)
- )
- if generation_config.num_beam_groups > 1:
- diverse_beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- max_length=max_len,
- num_beams=generation_config.num_beams,
- length_penalty=generation_config.length_penalty,
- do_early_stopping=generation_config.early_stopping,
- num_beam_hyps_to_keep=generation_config.num_return_sequences,
- num_beam_groups=generation_config.num_beam_groups,
- )
- # interleave with `num_beams`
- input_ids, model_kwargs = self.expand_inputs_for_generation(
- input_ids, expand_size=generation_config.num_beams, **model_kwargs
- )
- return self.group_beam_search(
- input_ids,
- diverse_beam_scorer,
- logits_processors,
- max_len,
- pad_token_id,
- eos_token_id,
- stopping_criteria=stopping_criteria,
- fast_ptq_sampling=generation_config.fast_ptq_sampling,
- trunc_input=generation_config.trunc_input,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
- else:
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- max_length=max_len,
- num_beams=generation_config.num_beams,
- length_penalty=generation_config.length_penalty,
- do_early_stopping=generation_config.early_stopping,
- num_beam_hyps_to_keep=generation_config.num_return_sequences,
- )
- input_ids, model_kwargs = self.expand_inputs_for_generation(
- input_ids, expand_size=generation_config.num_beams, **model_kwargs
- )
- return self.beam_search(
- input_ids,
- beam_scorer,
- logits_processors,
- max_len,
- generation_config.diversity_rate,
- pad_token_id,
- eos_token_id,
- stopping_criteria=stopping_criteria,
- fast_ptq_sampling=generation_config.fast_ptq_sampling,
- trunc_input=generation_config.trunc_input,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
- def greedy_search(
- self,
- input_ids,
- logits_processors,
- max_length,
- pad_token_id,
- eos_token_id,
- stopping_criteria=None,
- streamer=None,
- fast_ptq_sampling=False,
- trunc_input=True,
- synced_gpus=False,
- **model_kwargs,
- ):
- logits_processors = (
- logits_processors
- if logits_processors is not None
- else LogitsProcessorList()
- )
- # max_length will be converted to MaxLengthCriteria
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- # logging.warning(
- # "`max_length` is deprecated in this function, use"
- # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
- # )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- batch_size, cur_len = input_ids.shape
- origin_len = cur_len
- unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
- scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
- generate_end = False
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
- # prepare model inputs & get model output
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- outputs = self(**model_inputs)
- if synced_gpus and generate_end:
- continue # don't waste resources running the code we don't need
- if isinstance(outputs, tuple):
- logits = outputs[0]
- elif isinstance(outputs, ModelOutput):
- logits = outputs.logits
- else:
- logits = outputs
- # [batch_size, vocab_size]
- next_token_logits = logits[:, -1, :]
- # pre-process distribution
- next_token_logits = self.adjust_logits_during_generation(next_token_logits)
- probs = logits_processors(input_ids, next_token_logits)
- # greedy
- next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
- next_scores = paddle.index_sample(probs, next_tokens)
- if eos_token_id is not None:
- next_tokens = paddle.where(
- unfinished_flag,
- next_tokens,
- paddle.full_like(next_tokens, pad_token_id),
- )
- scores = self.update_scores_for_generation(
- scores, next_scores, cur_len - origin_len, unfinished_flag
- )
- cur_len += 1
- input_ids = paddle.concat([input_ids, next_tokens], axis=1)
- if streamer is not None:
- if self.config.tensor_parallel_rank == 0:
- streamer.put(next_tokens.cpu())
- if stopping_criteria(input_ids, scores):
- generate_end = True
- if eos_token_id is not None:
- unfinished_flag = get_unfinished_flag(
- input_ids, unfinished_flag, eos_token_id
- )
- if not paddle.any(unfinished_flag):
- generate_end = True
- # Stop when there is a </s> in all sentences
- if generate_end and not synced_gpus:
- break
- model_kwargs = self.update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- if fast_ptq_sampling:
- break
- if streamer is not None:
- streamer.end()
- return input_ids[:, origin_len:] if trunc_input else input_ids, scores
- def sample(
- self,
- input_ids,
- logits_processors,
- max_length,
- pad_token_id,
- eos_token_id,
- top_k=None,
- top_p=None,
- temperature=None,
- min_tokens_to_keep=1,
- stopping_criteria=None,
- streamer=None,
- fast_ptq_sampling=False,
- trunc_input=True,
- synced_gpus=False,
- **model_kwargs,
- ):
- logits_processors = (
- logits_processors
- if logits_processors is not None
- else LogitsProcessorList()
- )
- # max_length will be converted to MaxLengthCriteria
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- # logging.warning(
- # "`max_length` is deprecated in this function, use"
- # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
- # )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- batch_size, cur_len = input_ids.shape
- origin_len = cur_len
- unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
- scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
- generate_end = False
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
- # prepare model inputs & get model output
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- # NOTE: to decrease ref-count and clear outdated cache in time
- model_kwargs["cache"] = None
- model_kwargs["past_key_values"] = None
- outputs = self(**model_inputs)
- if synced_gpus and generate_end:
- continue # don't waste resources running the code we don't need
- if isinstance(outputs, tuple):
- logits = outputs[0]
- elif isinstance(outputs, ModelOutput):
- logits = outputs.logits
- else:
- logits = outputs
- # [batch_size, vocab_size]
- logits = logits[:, -1, :]
- # pre-process distribution
- logits = self.adjust_logits_during_generation(logits)
- logits = logits_processors(input_ids, logits)
- # sample
- origin_probs = F.softmax(logits)
- origin_probs = paddle.log(origin_probs)
- if temperature is not None and temperature != 1.0:
- logits = logits / temperature
- probs = F.softmax(logits)
- if top_k is not None and top_k != 0:
- probs = TopKProcess(probs, top_k, min_tokens_to_keep)
- if top_p is not None and top_p < 1.0:
- probs = TopPProcess(probs, top_p, min_tokens_to_keep)
- if paddle.device.is_compiled_with_custom_device("gcu"):
- probs = paddle.cast(probs, "float32")
- if paddle.device.is_compiled_with_xpu():
- probs = paddle.cast(probs, "float32")
- # multinomial already supports fp16 and bf16 currently; fixes issue: https://github.com/PaddlePaddle/Paddle/issues/51852
- next_tokens = paddle.multinomial(probs)
- if self.config.tensor_parallel_degree > 1:
- # Maybe no need to broadcast if seed is set correctly.
- from paddle.distributed import fleet
- try:
- hcg = fleet.get_hybrid_communicate_group()
- group = hcg.get_model_parallel_group()
- src = hcg.get_model_parallel_group_src_rank()
- except:
- group, src = None, 0
- paddle.distributed.broadcast(next_tokens, src=src, group=group)
- # config does not include pipeline_parallel_degree, and pipeline parallel
- # uses trainer.model_wrapped to run in both train and predict mode
- # which has pp_group as an attribute
- # TODO(guosheng): only let the last stage of pipeline to do softmax
- # and sampling, and then broadcast to avoid broadcast logits.
- if getattr(self, "pp_group", None) is not None:
- paddle.distributed.broadcast(
- next_tokens,
- src=self.pp_group.ranks[0],
- group=self.pp_group, # use rank 0 for same seed to check
- )
- next_scores = paddle.index_sample(origin_probs, next_tokens)
- if eos_token_id is not None:
- next_tokens = paddle.where(
- unfinished_flag,
- next_tokens,
- paddle.full_like(next_tokens, pad_token_id),
- )
- scores = self.update_scores_for_generation(
- scores, next_scores, cur_len - origin_len, unfinished_flag
- )
- cur_len += 1
- input_ids = paddle.concat([input_ids, next_tokens], axis=1)
- if streamer is not None:
- if self.config.tensor_parallel_rank == 0:
- streamer.put(next_tokens.cpu())
- if stopping_criteria(input_ids, scores):
- generate_end = True
- if eos_token_id is not None:
- unfinished_flag = get_unfinished_flag(
- input_ids, unfinished_flag, eos_token_id
- )
- if not paddle.any(unfinished_flag):
- generate_end = True
- # Stop when there is a </s> in all sentences
- if generate_end and not synced_gpus:
- break
- model_kwargs = self.update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
- )
- if fast_ptq_sampling:
- break
- if streamer is not None:
- streamer.end()
- return input_ids[:, origin_len:] if trunc_input else input_ids, scores
- def _get_model_inputs_spec(self, dtype: str):
- spec = {
- "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
- "attention_mask": paddle.static.InputSpec(
- shape=[None, None], dtype="int64"
- ),
- }
- if "position_ids" in inspect.getfullargspec(self.forward).args:
- spec["position_ids"] = paddle.static.InputSpec(
- shape=[None, None], dtype="int64"
- )
- return spec
- def to_static(self, path: str, config: dict):
- """export generation model to static
- Args:
- path (str): path of saved inference model
- config (dict): configuration for generation
- bos_token_id (int): token id of begin-of-sentence
- eos_token_id (int): token id of end-of-sentence
- pad_token_id (int): token id of pad token
- use_top_p (bool): whether to use the top_p decoding strategy
- """
- use_top_p = config.get("use_top_p", True)
- top_k_spec = (
- paddle.static.InputSpec(shape=[1], dtype="int64") if not use_top_p else 0
- )
- top_p_spec = (
- paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
- )
- temperature = (
- paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
- )
- dtype = config.get("dtype", None)
- logits_processors = config.get("logits_processors", None)
- model_inputs_spec = self._get_model_inputs_spec(dtype)
- input_spec = [
- model_inputs_spec["input_ids"], # input_ids
- model_inputs_spec["attention_mask"], # attention_mask
- model_inputs_spec.get("position_ids", None), # attention_mask
- logits_processors,
- paddle.static.InputSpec(shape=[1], dtype="int64"), # max_length
- self.generation_config.pad_token_id or config.get("pad_token_id", None),
- self.generation_config.eos_token_id or config.get("eos_token_id", None),
- top_k_spec, # top_k
- top_p_spec, # top_p
- temperature, # temperature
- 1,
- ]
- model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)
- paddle.jit.save(model, path)
- def sample_d2s(
- self,
- input_ids,
- attention_mask,
- position_ids,
- logits_processors,
- max_new_tokens,
- pad_token_id,
- eos_token_id,
- top_k=None,
- top_p=None,
- temperature=None,
- min_tokens_to_keep=1,
- ):
- pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
- logits_processors = (
- logits_processors
- if logits_processors is not None
- else LogitsProcessorList()
- )
- if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
- use_top_p = False
- elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
- use_top_p = True
- # top_k and top_p are constant values
- elif isinstance(top_p, float) or isinstance(top_k, int):
- use_top_p = True
- else:
- if top_p is None and top_k is None:
- raise ValueError("top_k and top_p must not both be None")
- raise ValueError(
- "do not pass InputSpec for both top_k and top_p; exactly one of them is expected to be an InputSpec"
- )
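- # Note: when to_static() exported the model with use_top_p=True, top_p and
- # temperature arrive here as InputSpec tensors while top_k is the constant 0,
- # so the dispatch above selects the top-p branch; with use_top_p=False the
- # roles are reversed and the TopKProcess branch is taken.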
- batch_size, cur_len = input_ids.shape
- # kept on GPU for computation, avoiding a D2H memcpy
- cur_len_gpu = paddle.full([1], cur_len, dtype="int64")
- origin_len = input_ids.shape[1]
- # kept on GPU for computation, avoiding a D2H memcpy
- origin_len_gpu = paddle.full([1], origin_len, dtype="int64")
- unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
- scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
- # use_cache is immutable; keep it separate from the other, mutable kwargs.
- immutable = {"use_cache": True}
- model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}
- def _forward_(**args):
- model_inputs = self.prepare_inputs_for_generation(
- input_ids, **args, **immutable
- )
- assert "use_cache" in model_inputs
- del model_inputs["use_cache"]
- return self(**model_inputs, **immutable)
- def _post_process_(
- outputs,
- input_ids,
- cur_len,
- origin_len,
- scores,
- unfinished_flag,
- model_kwargs,
- pad_token_id,
- ):
- if isinstance(outputs, tuple):
- logits = outputs[0]
- elif isinstance(outputs, ModelOutput):
- logits = outputs.logits
- else:
- logits = outputs
- # [batch_size, vocab_size]
- logits = logits[:, -1, :]
- # pre-process distribution
- logits = self.adjust_logits_during_generation(logits)
- logits = logits_processors(input_ids, logits)
- probs = F.softmax(logits)
- # sample
- origin_probs = F.log_softmax(logits)
- # compute next_tokens
- if use_top_p:
- # apply temperature before sampling so that it actually reshapes the distribution
- logits = logits / temperature
- probs = F.softmax(logits)
- top_ps_tensor = paddle.full(
- shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
- )
- _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
- else:
- probs = TopKProcess(probs, top_k, min_tokens_to_keep)
- if top_k == 1:
- next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
- else:
- next_tokens = paddle.multinomial(probs)
- next_scores = paddle.index_sample(origin_probs, next_tokens)
- scores = self.update_scores_for_generation(
- scores, next_scores, cur_len - origin_len, unfinished_flag
- )
- if eos_token_id is not None:
- next_tokens = paddle.where(
- unfinished_flag,
- next_tokens,
- paddle.full_like(next_tokens, pad_token_id),
- )
- input_ids = paddle.concat([input_ids, next_tokens], axis=1)
- if eos_token_id is not None:
- unfinished_flag = get_unfinished_flag(
- input_ids, unfinished_flag, eos_token_id
- )
- model_kwargs = self.update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- return input_ids, scores, unfinished_flag, model_kwargs
- outputs = _forward_(**model_kwargs)
- input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
- outputs,
- input_ids,
- cur_len_gpu,
- origin_len_gpu,
- scores,
- unfinished_flag,
- model_kwargs,
- pad_token_id,
- )
- cur_len += 1
- cur_len_gpu += 1
- attn_mask = model_kwargs["attention_mask"]
- # make the shape of attention_mask fully dynamic, i.e. (-1, -1, -1, -1), in dy2static.
- model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
- model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
- max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")
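- # The line above converts the requested number of new tokens into an absolute length
- # bound: one token was already generated above, so the bound equals origin_len + max_new_tokens
- # and the while-loop below emits at most (max_new_tokens - 1) additional tokens.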
- while cur_len < max_new_tokens and paddle.any(unfinished_flag):
- input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
- _forward_(**model_kwargs),
- input_ids,
- cur_len_gpu,
- origin_len_gpu,
- scores,
- unfinished_flag,
- model_kwargs,
- pad_token_id,
- )
- cur_len += 1
- cur_len_gpu += 1
- return input_ids[:, origin_len:], scores
- def reorder_cache(self, cache, beam_idx):
- cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
- return cache
- def beam_search(
- self,
- input_ids,
- beam_scorer,
- logits_processors,
- max_length,
- diversity_rate,
- pad_token_id,
- eos_token_id,
- stopping_criteria=None,
- fast_ptq_sampling=False,
- trunc_input=True,
- synced_gpus=False,
- **model_kwargs,
- ):
- logits_processors = (
- logits_processors
- if logits_processors is not None
- else LogitsProcessorList()
- )
- # max_length will be converted to a MaxLengthCriteria
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- # logging.warning(
- # "`max_length` is deprecated in this function, use"
- # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
- # )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
- batch_beam_size, cur_len = input_ids.shape
- origin_len = cur_len
- assert (
- num_beams * batch_size == batch_beam_size
- ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
- num_beams * batch_size, batch_beam_size
- )
- beam_scores = paddle.zeros(
- (batch_size, num_beams), dtype=paddle.get_default_dtype()
- )
- beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
- beam_scores = paddle.reshape(beam_scores, [-1])
- generate_end = False
- while True:
- if synced_gpus:
- # Under synced_gpus, the `forward` call must continue until all GPUs complete their sequences.
- # The following logic allows an early break once all peers have finished generating their sequences.
- this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
- # prepare model inputs & get model output
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- outputs = self(**model_inputs)
- if synced_gpus and generate_end:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
- if isinstance(outputs, tuple):
- logits = outputs[0]
- elif isinstance(outputs, ModelOutput):
- logits = outputs.logits
- else:
- logits = outputs
- # [batch_size, vocab_size]
- logits = logits[:, -1, :]
- # pre-process distribution
- logits = self.adjust_logits_during_generation(logits)
- # beam search
- # [batch_size * num_beams, vocab_size]
- next_scores = F.softmax(logits)
- next_scores = paddle.log(next_scores)
- next_scores = logits_processors(input_ids, next_scores)
- next_scores = next_scores + beam_scores.unsqueeze(-1)
- vocab_size = next_scores.shape[-1]
- if diversity_rate == 0.0:
- # reshape for beam search
- next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])
- next_scores, next_tokens = paddle.topk(
- next_scores, 2 * num_beams, axis=1
- )
- next_indices = next_tokens // vocab_size
- next_tokens = next_tokens % vocab_size
- else:
- next_scores, next_tokens = paddle.topk(
- next_scores, 2 * num_beams, axis=1
- )
- sibling_score = (
- paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
- * diversity_rate
- )
- diversed_score = next_scores - sibling_score
- next_scores = next_scores.reshape(
- [batch_size, 2 * num_beams * num_beams]
- )
- next_tokens = next_tokens.reshape(
- [batch_size, 2 * num_beams * num_beams]
- )
- diversed_score = diversed_score.reshape(
- [batch_size, 2 * num_beams * num_beams]
- )
- diversed_score, diversed_tokens = paddle.topk(
- diversed_score, 2 * num_beams, axis=1
- )
- # TODO: use gather_nd() to select the original tokens and scores
- next_scores = paddle.stack(
- [
- paddle.index_select(next_scores[i], diversed_tokens[i])
- for i in range(next_scores.shape[0])
- ]
- )
- next_tokens = paddle.stack(
- [
- paddle.index_select(next_tokens[i], diversed_tokens[i])
- for i in range(next_tokens.shape[0])
- ]
- )
- next_indices = diversed_tokens // (2 * num_beams)
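- # Worked example of the sibling penalty above: with num_beams=2 and
- # diversity_rate=0.5, each beam's 2*num_beams=4 candidates receive penalties
- # [0.5, 1.0, 1.5, 2.0] in rank order, so lower-ranked siblings of the same beam
- # are pushed down and candidates from other beams get a chance to be selected.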
- # stateless
- beam_outputs = beam_scorer.process(
- input_ids,
- next_scores,
- next_tokens,
- next_indices,
- origin_len=origin_len,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- beam_scores = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
- # beam_idx may contain -1, which would cause an error
- # see: https://github.com/PaddlePaddle/Paddle/issues/57366
- beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
- cur_len += 1
- input_ids = paddle.concat(
- [
- paddle.index_select(input_ids, beam_idx),
- beam_next_tokens.unsqueeze(-1),
- ],
- axis=-1,
- )
- if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
- if not synced_gpus:
- break
- else:
- generate_end = True
- model_kwargs = self.update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
- )
- if "cache" in model_kwargs:
- # reorder the cache
- model_kwargs["cache"] = self.reorder_cache(
- model_kwargs["cache"], beam_idx
- )
- if "past_key_values" in model_kwargs:
- # reorder the cache
- model_kwargs["past_key_values"] = self.reorder_cache(
- model_kwargs["past_key_values"], beam_idx
- )
- if fast_ptq_sampling:
- break
- pred_ids, scores = beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- origin_len=origin_len,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
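- # Hedged usage note (not in the original source): beam_search() is normally not
- # called directly; generate() builds the BeamSearchScorer, tiles input_ids
- # num_beams times and dispatches here. A rough caller-side sketch, with argument
- # names assumed:
- #
- #     out_ids, out_scores = model.generate(
- #         input_ids=tokenizer("hello", return_tensors="pd")["input_ids"],
- #         decode_strategy="beam_search",
- #         num_beams=4,
- #         length_penalty=1.2,
- #         max_new_tokens=64,
- #     )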
- def group_beam_search(
- self,
- input_ids,
- beam_scorer,
- logits_processors,
- max_length,
- pad_token_id,
- eos_token_id,
- stopping_criteria=None,
- fast_ptq_sampling=False,
- trunc_input=True,
- synced_gpus=False,
- **model_kwargs,
- ):
- logits_processors = (
- logits_processors
- if logits_processors is not None
- else LogitsProcessorList()
- )
- # max_length will be converted to a MaxLengthCriteria
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- # logging.warning(
- # "`max_length` is deprecated in this function, use"
- # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
- # )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
- num_beam_groups = beam_scorer.num_beam_groups
- num_sub_beams = num_beams // num_beam_groups
- batch_beam_size, cur_len = input_ids.shape
- origin_len = cur_len
- assert (
- num_beams * batch_size == batch_beam_size
- ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
- num_beams * batch_size, batch_beam_size
- )
- beam_scores = paddle.full(
- (batch_size, num_beams),
- get_scale_by_dtype(return_positive=False),
- dtype="float32",
- )
- # initialise the score of the first beam of each group with 0 and the rest with a large negative value
- # (get_scale_by_dtype(return_positive=False)). This ensures that beams in the same group don't keep producing the same tokens.
- beam_scores[:, ::num_sub_beams] = 0
- beam_scores = paddle.reshape(beam_scores, [-1])
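- # e.g. with num_beams=4 and num_beam_groups=2 (num_sub_beams=2), each batch row
- # starts as [0, NEG, 0, NEG], NEG being get_scale_by_dtype(return_positive=False),
- # so only the first beam of every group is "live" at the first step.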
- generate_end = False
- while True:
- if synced_gpus:
- # Under synced_gpus, the `forward` call must continue until all GPUs complete their sequences.
- # The following logic allows an early break once all peers have finished generating their sequences.
- this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
- # tokens predicted at the current step
- current_tokens = paddle.zeros(
- shape=[batch_size * num_beams], dtype=input_ids.dtype
- )
- # indices which will form the beams in the next time step
- reordering_indices = paddle.zeros(
- shape=[batch_size * num_beams], dtype="int64"
- )
- # prepare model inputs & get model output
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- outputs = self(**model_inputs)
- if synced_gpus and generate_end:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
- for beam_group_idx in range(num_beam_groups):
- group_start_idx = beam_group_idx * num_sub_beams
- group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
- group_size = group_end_idx - group_start_idx
- # indices of the current group's beams among all sentences in the batch
- batch_group_indices = []
- for batch_idx in range(batch_size):
- batch_group_indices.extend(
- [
- batch_idx * num_beams + idx
- for idx in range(group_start_idx, group_end_idx)
- ]
- )
- group_input_ids = input_ids[batch_group_indices]
- if isinstance(outputs, tuple):
- logits = outputs[0]
- elif isinstance(outputs, ModelOutput):
- logits = outputs.logits
- else:
- logits = outputs
- logits = logits[:, -1, :]
- logits = paddle.index_select(
- logits, paddle.to_tensor(batch_group_indices)
- )
- logits = self.adjust_logits_during_generation(logits)
- next_scores = F.softmax(logits)
- next_scores = paddle.log(next_scores)
- vocab_size = next_scores.shape[-1]
- next_scores = logits_processors(
- group_input_ids,
- next_scores,
- current_tokens=current_tokens,
- beam_group_idx=beam_group_idx,
- )
- next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
- -1
- )
- # reshape for beam search
- next_scores = next_scores.reshape([batch_size, group_size * vocab_size])
- next_scores, next_tokens = paddle.topk(
- next_scores, 2 * group_size, axis=1
- )
- next_indices = next_tokens // vocab_size
- next_tokens = next_tokens % vocab_size
- beam_outputs = beam_scorer.process(
- group_input_ids,
- next_scores,
- next_tokens,
- next_indices,
- origin_len=origin_len,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
- # beam_idx may contain -1, which would cause an error
- # see: https://github.com/PaddlePaddle/Paddle/issues/57366
- beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
- input_ids[batch_group_indices] = group_input_ids[beam_idx]
- group_input_ids = paddle.concat(
- [
- paddle.index_select(group_input_ids, index=beam_idx),
- beam_next_tokens.unsqueeze(-1),
- ],
- axis=-1,
- )
- current_tokens[batch_group_indices] = beam_next_tokens
- reordering_indices[batch_group_indices] = (
- num_beams * (beam_idx // group_size)
- + group_start_idx
- + (beam_idx % group_size)
- )
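- # beam_idx indexes beams inside this group's (batch_size * group_size) block;
- # the expression above maps it back to a global (batch_size * num_beams) layout:
- # batch row = beam_idx // group_size, slot in that row = group_start_idx + beam_idx % group_size.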
- input_ids = paddle.concat(
- [input_ids, current_tokens.unsqueeze(-1)], axis=-1
- )
- cur_len += 1
- if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
- if not synced_gpus:
- break
- else:
- generate_end = True
- model_kwargs = self.update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
- )
- if "cache" in model_kwargs:
- # reorder the cache
- model_kwargs["cache"] = self.reorder_cache(
- model_kwargs["cache"], reordering_indices
- )
- if "past_key_values" in model_kwargs:
- # reorder the cache
- model_kwargs["past_key_values"] = self.reorder_cache(
- model_kwargs["past_key_values"], reordering_indices
- )
- if fast_ptq_sampling:
- break
- pred_ids, scores = beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- origin_len=origin_len,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
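- # Hedged note (not in the original source): group_beam_search is reached through
- # generate() when num_beam_groups > 1. Each group is scored independently at every
- # step, and the logits processors receive current_tokens / beam_group_idx so that a
- # diversity-aware processor (if one is configured) can penalise tokens already chosen
- # by earlier groups at the same position.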