```python
from typing import List

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class Mineru2LogitProcessor(CustomLogitProcessor):
    """
    Stateless logit processor for Mineru2.
    (Base class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)

    This processor applies token-level constraints to prevent repetition during
    generation. It supports two main constraints:

    - no_repeat_ngram_size (int):
        Prevents repeating the same n-gram of the specified size in the output.
        Inspired by Hugging Face's NoRepeatNGramLogitsProcessor, but slower,
        since this implementation lacks that one's specialized optimizations.

    - no_repeat_token_count (int):
        (Placeholder for future logic.)
        Intended to prevent repeating the same token multiple times.
        Not yet implemented in this version.
    """

    def __init__(self) -> None:
        super().__init__()
        self._generated_ngrams = {}  # Cache of generated n-grams, keyed by request ID
        self._time = {}  # Generation step at which each request was last seen
        self._gen_step = 0  # Global generation step counter

    def __call__(self, logits, batch_info: List[dict]):
        """
        Applies repetition constraints to the logits before sampling tokens.

        Args:
            logits (FloatTensor): A tensor of shape (batch_size, vocab_size)
                containing raw token logits.
            batch_info (List[dict]): A list of metadata dicts for each sample
                in the batch. Each dict must include:
                - "__req__": Request object containing the request ID and output_ids.
                - "no_repeat_ngram_size": Size of the n-gram to avoid repeating.

        Returns:
            FloatTensor: The modified logits tensor with banned token logits set to -inf.
        """
        from sglang.srt.managers.schedule_batch import Req

        self._gen_step += 1  # Update the global generation step

        for idx, info in enumerate(batch_info):
            if not isinstance(info, dict) or "__req__" not in info:
                continue

            req: Req = info["__req__"]
            rid = req.rid
            output_ids = req.output_ids
            ngram_size = info.get("no_repeat_ngram_size", 0)

            # Skip if there are not enough tokens to form an n-gram
            if ngram_size <= 0 or len(output_ids) < ngram_size:
                continue

            # Record the current step for cache-cleanup tracking
            self._time[rid] = self._gen_step

            # Initialize the n-gram cache for this request if it doesn't exist
            if rid not in self._generated_ngrams:
                self._generated_ngrams[rid] = {}

            # Store the newest n-gram: its prefix (all but the last token)
            # maps to the token that followed it
            prev_ngram = tuple(output_ids[-ngram_size:-1])
            last_token = output_ids[-1]
            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]

            # Get the next-token candidates to ban based on the current prefix
            current_prefix = tuple(output_ids[-ngram_size + 1 :])
            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])

            # Set the logits of banned tokens to negative infinity
            for token in banned_tokens:
                logits[idx][token] = -float("inf")

        # Clean up cached state for requests that did not appear in this batch
        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
        for rid in expired_rids:
            self._generated_ngrams.pop(rid, None)
            self._time.pop(rid, None)

        return logits
```
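To make the update/ban cycle concrete, here is a minimal sketch that drives `__call__` by hand over three simulated decode steps. The `SimpleNamespace` request is a stand-in for sglang's `Req` (only `.rid` and `.output_ids` are read), not a real sglang object; sglang and torch must still be installed, since `__call__` imports sglang internals at call time.

```python
# Minimal sketch of the ban logic (SimpleNamespace stands in for sglang's Req;
# only .rid and .output_ids are accessed).
from types import SimpleNamespace

import torch

proc = Mineru2LogitProcessor()

# Simulate three decode steps of one request whose output grows to
# [5, 6, 7, 5, 6]. With no_repeat_ngram_size=3, the trigram (5, 6, 7) was
# already generated, so 7 must be banned as the next token.
for output_ids in ([5, 6, 7], [5, 6, 7, 5], [5, 6, 7, 5, 6]):
    req = SimpleNamespace(rid="demo", output_ids=list(output_ids))
    logits = proc(
        torch.zeros(1, 10),  # batch_size=1, toy vocabulary of 10 tokens
        [{"__req__": req, "no_repeat_ngram_size": 3}],
    )

print(logits[0, 7])  # tensor(-inf): prefix (5, 6) was already continued by 7
print(logits[0, 8])  # tensor(0.):  every other token remains unbanned
```

Note that the cache is built incrementally: each call records only the single n-gram ending at the newest token, which is why the processor must see every decode step of a request, and why the end-of-call sweep can safely evict any request ID that was absent from the current batch.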
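Wiring the processor into sglang is a separate step. The sketch below follows sglang's custom-logit-processor interface as I understand it (the `enable_custom_logit_processor` flag, the `custom_params` entry in `sampling_params`, the `custom_logit_processor` argument, and `CustomLogitProcessor.to_str()` serialization); treat the exact names as assumptions to verify against your installed sglang version, and the model path as a placeholder.

```python
# Hedged usage sketch: flag and parameter names follow sglang's
# custom-logit-processor API as of recent releases -- verify against your
# installed version. The model path is a placeholder.
import sglang as sgl

llm = sgl.Engine(
    model_path="/path/to/model",         # placeholder
    enable_custom_logit_processor=True,  # the server must opt in to custom processors
)

out = llm.generate(
    prompt="...",
    sampling_params={
        "temperature": 0.0,
        "max_new_tokens": 128,
        # Forwarded into each sample's batch_info dict alongside "__req__"
        "custom_params": {"no_repeat_ngram_size": 3},
    },
    # to_str() serializes the processor so it can travel with the request
    custom_logit_processor=Mineru2LogitProcessor().to_str(),
)
print(out["text"])
```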