logit_processor.py

from typing import List

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class Mineru2LogitProcessor(CustomLogitProcessor):
  4. """
  5. Stateless logit processor for Mineru2.
  6. (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
  7. This processor applies token-level constraints to prevent repetition during generation.
  8. It supports two main constraints:
  9. - no_repeat_ngram_size (int):
  10. Prevents repeating the same n-gram of specified size in the output.
  11. Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
  12. This implementation is slower due to its lack of specialized optimization.
  13. - no_repeat_token_count (int):
  14. (Placeholder for future logic)
  15. Intended to prevent repeating the same token multiple times.
  16. Not yet implemented in this version.
  17. """

    def __init__(self) -> None:
        super().__init__()
        self._generated_ngrams = {}  # Cache of generated n-grams, keyed by request ID
        self._time = {}  # Generation step at which each request was last seen
        self._gen_step = 0  # Global generation step counter

    def __call__(self, logits, batch_info: List[dict]):
        """
        Applies repetition constraints to the logits before sampling tokens.

        Args:
            logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
            batch_info (List[dict]): A list of metadata dicts, one per sample in the batch. Each dict must include:
                - "__req__": Request object containing the request ID and output_ids.
                - "no_repeat_ngram_size": Size of the n-gram that must not be repeated.

        Returns:
            FloatTensor: The modified logits tensor with banned token logits set to -inf.
        """
        from sglang.srt.managers.schedule_batch import Req

        self._gen_step += 1  # Advance the global generation step

        for idx, info in enumerate(batch_info):
            if not isinstance(info, dict) or "__req__" not in info:
                continue

            req: Req = info["__req__"]
            rid = req.rid
            output_ids = req.output_ids
            ngram_size = info.get("no_repeat_ngram_size", 0)

            # Skip if there are not enough tokens to form an n-gram
            if ngram_size <= 0 or len(output_ids) < ngram_size:
                continue

            # Record the current step for cache-cleanup tracking
            self._time[rid] = self._gen_step

            # Initialize the n-gram cache for this request if it doesn't exist
            if rid not in self._generated_ngrams:
                self._generated_ngrams[rid] = {}

            # Record the n-gram just completed: its prefix (all but the last token)
            # maps to the last token
            prev_ngram = tuple(output_ids[-ngram_size:-1])
            last_token = output_ids[-1]
            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]

            # Look up the next-token candidates to ban for the current prefix
            current_prefix = tuple(output_ids[-ngram_size + 1 :])
            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])

            # Set the logits of banned tokens to negative infinity
            for token in banned_tokens:
                logits[idx][token] = -float("inf")

        # Clean up caches for requests that did not refresh self._time this step
        # (they were not in this batch and are assumed to have finished)
        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
        for rid in expired_rids:
            self._generated_ngrams.pop(rid, None)
            self._time.pop(rid, None)

        return logits
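

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the upstream interface).
# It assumes torch and sglang are importable, exactly as the imports above
# already require, and it stands in for the sglang Req object with a
# types.SimpleNamespace, since the processor only reads `rid` and
# `output_ids` from it. In the real server, sglang constructs the Req and
# batch_info dicts itself; this block only exercises the n-gram ban in isolation.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    processor = Mineru2LogitProcessor()
    vocab_size = 8
    req = SimpleNamespace(rid="demo-request", output_ids=[1, 2])

    # Simulate three decode steps that append tokens 3, 1, 2. On the third step
    # the prefix (1, 2) has already been followed by 3, so token 3 is banned to
    # avoid repeating the trigram (1, 2, 3).
    for next_token in (3, 1, 2):
        req.output_ids.append(next_token)
        logits = torch.zeros(1, vocab_size)
        logits = processor(logits, [{"__req__": req, "no_repeat_ngram_size": 3}])

    print(logits[0].tolist())  # index 3 is -inf, all other entries stay 0.0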