Răsfoiți Sursa

adapt to Python 3.10 and fix corner cases for the Whisper model (#2843)

* adapt to Python 3.10 and fix corner cases

* fix

* fix

* fix

* add temperature adaptation
zxcd 10 luni în urmă
părinte
comite
387d18ab9a

+ 20 - 1
paddlex/inference/models_new/multilingual_speech_recognition/predictor.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import lazy_paddle as paddle
+import numpy as np
 
 from ....utils.func_register import FuncRegister
 from ...common.batch_sampler import AudioBatchSampler
@@ -99,6 +100,24 @@ class WhisperPredictor(BasicPredictor):
         audio = audio[:, 0]
         audio = log_mel_spectrogram(audio, resource_path=self.config["resource_dir"])
 
+        # adapt temperature
+        temperature_increment_on_fallback = self.config[
+            "temperature_increment_on_fallback"
+        ]
+        if (
+            temperature_increment_on_fallback is not None
+            and temperature_increment_on_fallback != "None"
+        ):
+            temperature = tuple(
+                np.arange(
+                    self.config["temperature"],
+                    1.0 + 1e-6,
+                    temperature_increment_on_fallback,
+                )
+            )
+        else:
+            temperature = [self.config["temperature"]]
+
         # model inference
         result = self.model.transcribe(
             audio,
@@ -106,7 +125,7 @@ class WhisperPredictor(BasicPredictor):
             task=self.config["task"],
             language=self.config["language"],
             resource_path=self.config["resource_dir"],
-            temperature=self.config["temperature"],
+            temperature=temperature,
             compression_ratio_threshold=self.config["compression_ratio_threshold"],
             logprob_threshold=self.config["logprob_threshold"],
             best_of=self.config["best_of"],

+ 8 - 6
paddlex/inference/models_new/multilingual_speech_recognition/processors.py

@@ -494,7 +494,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
 
-        w = paddle.nn.functional.softmax(qk.astype("float32"), axis=-1).to(q.dtype)
+        w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
         return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
 
 
@@ -936,10 +936,12 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
+    initial_prompt = decode_options.pop("initial_prompt", None)
+    if initial_prompt and initial_prompt != "None":
         initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
         all_tokens.extend(initial_prompt)
+    else:
+        initial_prompt = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: paddle.Tensor, result: DecodingResult
@@ -959,7 +961,7 @@ def transcribe(
                 "text": text,
                 "tokens": result.tokens,
                 "temperature": result.temperature,
-                "avg_logprob": result.avg_logprob.tolist(),
+                "avg_logprob": result.avg_logprob,
                 "compression_ratio": result.compression_ratio,
                 "no_speech_prob": result.no_speech_prob,
             }
@@ -1253,8 +1255,8 @@ class BeamSearchDecoder(TokenDecoder):
                 prefix = tokens[idx].tolist()
                 logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
                 for logprob, token in zip(logprob, token):
-                    new_logprob = sum_logprobs[idx] + logprob
-                    sequence = tuple(prefix + [token])
+                    new_logprob = (sum_logprobs[idx] + logprob).item()
+                    sequence = tuple(prefix + [token.item()])
                     scores[sequence] = new_logprob
                     sources[sequence] = idx