Эх сурвалжийг харах

Adapt Windows and Mac folder paths. (#2923)

zxcd 10 сар өмнө
parent
commit
2dc0bfb72a

+ 5 - 6
paddlex/inference/models_new/multilingual_speech_recognition/predictor.py

@@ -38,9 +38,7 @@ class WhisperPredictor(BasicPredictor):
         """
         """
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         self.audio_reader = self._build()
         self.audio_reader = self._build()
-        download_and_extract(
-            self.config["resource_path"], self.config["resource_dir"], "assets"
-        )
+        download_and_extract(self.config["resource_path"], self.model_dir, "assets")
 
 
     def _build_batch_sampler(self):
     def _build_batch_sampler(self):
         """Builds and returns an AudioBatchSampler instance.
         """Builds and returns an AudioBatchSampler instance.
@@ -72,7 +70,8 @@ class WhisperPredictor(BasicPredictor):
         )
         )
 
 
         # build model
         # build model
-        model_dict = paddle.load(self.config["model_file"])
+        model_file = (self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdparams").as_posix()
+        model_dict = paddle.load(model_file)
         dims = ModelDimensions(**model_dict["dims"])
         dims = ModelDimensions(**model_dict["dims"])
         self.model = Whisper(dims)
         self.model = Whisper(dims)
         self.model.load_dict(model_dict)
         self.model.load_dict(model_dict)
@@ -98,7 +97,7 @@ class WhisperPredictor(BasicPredictor):
         audio, sample_rate = self.audio_reader.read(batch_data[0])
         audio, sample_rate = self.audio_reader.read(batch_data[0])
         audio = paddle.to_tensor(audio)
         audio = paddle.to_tensor(audio)
         audio = audio[:, 0]
         audio = audio[:, 0]
-        audio = log_mel_spectrogram(audio, resource_path=self.config["resource_dir"])
+        audio = log_mel_spectrogram(audio, resource_path=self.model_dir)
 
 
         # adapt temperature
         # adapt temperature
         temperature_increment_on_fallback = self.config[
         temperature_increment_on_fallback = self.config[
@@ -124,7 +123,7 @@ class WhisperPredictor(BasicPredictor):
             verbose=self.config["verbose"],
             verbose=self.config["verbose"],
             task=self.config["task"],
             task=self.config["task"],
             language=self.config["language"],
             language=self.config["language"],
-            resource_path=self.config["resource_dir"],
+            resource_path=self.model_dir,
             temperature=temperature,
             temperature=temperature,
             compression_ratio_threshold=self.config["compression_ratio_threshold"],
             compression_ratio_threshold=self.config["compression_ratio_threshold"],
             logprob_threshold=self.config["logprob_threshold"],
             logprob_threshold=self.config["logprob_threshold"],