瀏覽代碼

fix slice bug (#4470)

supotato6 2 月之前
父節點
當前提交
1eff3be4b2
共有 1 個文件被更改,包括 6 次插入0 次删除
  1. 6 0
      paddlex/inference/models/multilingual_speech_recognition/processors.py

+ 6 - 0
paddlex/inference/models/multilingual_speech_recognition/processors.py

@@ -1342,6 +1342,7 @@ class SuppressBlank(LogitFilter):
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         if tokens.shape[1] == self.sample_begin:
+            logits.contiguous()
             logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
                 -np.inf
             )
@@ -1352,6 +1353,7 @@ class SuppressTokens(LogitFilter):
         self.suppress_tokens = list(suppress_tokens)
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+        logits.contiguous()
         logits[:, self.suppress_tokens] = -np.inf
 
 
@@ -1369,6 +1371,7 @@ class ApplyTimestampRules(LogitFilter):
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         # suppress <|notimestamps|> which is handled by without_timestamps
         if self.tokenizer.no_timestamps is not None:
+            logits.contiguous()
             logits[:, self.tokenizer.no_timestamps] = -np.inf
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -1382,6 +1385,7 @@ class ApplyTimestampRules(LogitFilter):
             )
 
             if last_was_timestamp:
+                logits.contiguous()
                 if penultimate_was_timestamp:  # has to be non-timestamp
                     logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                 else:  # cannot be normal text tokens
@@ -1395,6 +1399,7 @@ class ApplyTimestampRules(LogitFilter):
             last_allowed = (
                 self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
             )
+            logits.contiguous()
             logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
@@ -1413,6 +1418,7 @@ class ApplyTimestampRules(LogitFilter):
                 logprobs[k, : self.tokenizer.timestamp_begin]
             )
             if timestamp_logprob > max_text_token_logprob:
+                logits.contiguous()
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf