|
|
@@ -38,9 +38,7 @@ class WhisperPredictor(BasicPredictor):
|
|
|
"""
|
|
|
super().__init__(*args, **kwargs)
|
|
|
self.audio_reader = self._build()
|
|
|
- download_and_extract(
|
|
|
- self.config["resource_path"], self.config["resource_dir"], "assets"
|
|
|
- )
|
|
|
+ download_and_extract(self.config["resource_path"], self.model_dir, "assets")
|
|
|
|
|
|
def _build_batch_sampler(self):
|
|
|
"""Builds and returns an AudioBatchSampler instance.
|
|
|
@@ -72,7 +70,8 @@ class WhisperPredictor(BasicPredictor):
|
|
|
)
|
|
|
|
|
|
# build model
|
|
|
- model_dict = paddle.load(self.config["model_file"])
|
|
|
+ model_file = (self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdparams").as_posix()
|
|
|
+ model_dict = paddle.load(model_file)
|
|
|
dims = ModelDimensions(**model_dict["dims"])
|
|
|
self.model = Whisper(dims)
|
|
|
self.model.load_dict(model_dict)
|
|
|
@@ -98,7 +97,7 @@ class WhisperPredictor(BasicPredictor):
|
|
|
audio, sample_rate = self.audio_reader.read(batch_data[0])
|
|
|
audio = paddle.to_tensor(audio)
|
|
|
audio = audio[:, 0]
|
|
|
- audio = log_mel_spectrogram(audio, resource_path=self.config["resource_dir"])
|
|
|
+ audio = log_mel_spectrogram(audio, resource_path=self.model_dir)
|
|
|
|
|
|
# adapt temperature
|
|
|
temperature_increment_on_fallback = self.config[
|
|
|
@@ -124,7 +123,7 @@ class WhisperPredictor(BasicPredictor):
|
|
|
verbose=self.config["verbose"],
|
|
|
task=self.config["task"],
|
|
|
language=self.config["language"],
|
|
|
- resource_path=self.config["resource_dir"],
|
|
|
+ resource_path=self.model_dir,
|
|
|
temperature=temperature,
|
|
|
compression_ratio_threshold=self.config["compression_ratio_threshold"],
|
|
|
logprob_threshold=self.config["logprob_threshold"],
|