@@ -1342,6 +1342,7 @@ class SuppressBlank(LogitFilter):
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         if tokens.shape[1] == self.sample_begin:
+            logits.contiguous()
             logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
                 -np.inf
             )
@@ -1352,6 +1353,7 @@ class SuppressTokens(LogitFilter):
         self.suppress_tokens = list(suppress_tokens)
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+        logits.contiguous()
         logits[:, self.suppress_tokens] = -np.inf
 
 
@@ -1369,6 +1371,7 @@ class ApplyTimestampRules(LogitFilter):
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         # suppress <|notimestamps|> which is handled by without_timestamps
         if self.tokenizer.no_timestamps is not None:
+            logits.contiguous()
             logits[:, self.tokenizer.no_timestamps] = -np.inf
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -1382,6 +1385,7 @@ class ApplyTimestampRules(LogitFilter):
             )
 
             if last_was_timestamp:
+                logits.contiguous()
                 if penultimate_was_timestamp:  # has to be non-timestamp
                     logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                 else:  # cannot be normal text tokens
@@ -1395,6 +1399,7 @@ class ApplyTimestampRules(LogitFilter):
             last_allowed = (
                 self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
             )
+            logits.contiguous()
             logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
@@ -1413,6 +1418,7 @@ class ApplyTimestampRules(LogitFilter):
                 logprobs[k, : self.tokenizer.timestamp_begin]
             )
             if timestamp_logprob > max_text_token_logprob:
+                logits.contiguous()
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
 
 
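Note on the pattern this patch applies: every added `logits.contiguous()` immediately precedes an in-place indexed assignment (`logits[...] = -np.inf`). Under Paddle's stride mechanism, `logits` can arrive as a non-contiguous strided view, and `Tensor.contiguous()` materializes contiguous storage before `__setitem__` writes into it; the bare call with no rebinding suggests the call also updates the tensor's storage in place. Below is a minimal standalone sketch of the idea, assuming Paddle 2.6+ with stride support; the tensor shapes and masked columns are illustrative, not taken from the model.

import numpy as np
import paddle

# Illustrative tensor only: a transpose produces a strided view, which is
# typically non-contiguous when Paddle's stride mechanism is enabled.
logits = paddle.randn([8, 4]).t()
print(logits.is_contiguous())   # usually False with stride support on

# Materialize contiguous storage before in-place masking. Rebinding the
# name is the conservative form; the patch's bare `logits.contiguous()`
# appears to rely on the call rearranging the tensor's storage in place.
logits = logits.contiguous()

# In-place masking, mirroring the LogitFilter pattern in the diff above.
logits[:, :2] = -np.inf
print(logits.is_contiguous())   # True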