Kaynağa Gözat

add whisper pipeline (#2789)

zxcd 10 ay önce
ebeveyn
işleme
e7ce2ec75b

+ 25 - 0
api_examples/pipelines/test_multilingual_speech_recognition.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="multilingual_speech_recognition")
+
+output = pipeline.predict("https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav")
+
+
+for res in output:
+    print(res)
+    res.print()  ## print the structured prediction result
+    res.save_to_json("./output/")  ## save the structured prediction result as JSON

+ 9 - 0
paddlex/configs/pipelines/multilingual_speech_recognition.yaml

@@ -0,0 +1,9 @@
+
+pipeline_name: multilingual_speech_recognition
+
+SubModules:
+  MultilingualSpeechRecognition:
+    module_name: multilingual_speech_recognition
+    model_name: whisper_large  # name of the speech-recognition model to create
+    model_dir: null  # null presumably means "use the default/official weights" — verify against create_model
+    batch_size: 1   # the pipeline only supports batch size 1 (see pipeline.py)

+ 3 - 3
paddlex/inference/common/reader/audio_reader.py

@@ -26,12 +26,12 @@ class ReadAudio:
 
         """
         super().__init__()
-        self._audio_reader = AudioReader()
+        self._audio_reader = AudioReader(backend="wav")
 
     def read(self, input):
         if isinstance(input, str):
-            audio, sample_rate = self._audio_reader.read_file(input)
-            if sample_rate != "16000":
+            audio, sample_rate = self._audio_reader.read(input)
+            if sample_rate != 16000:
                 raise ValueError(
                     f"ReadAudio only supports 16k pcm or wav file.\n"
                     f"However, got: {sample_rate}."

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -26,6 +26,7 @@ from .image_classification import ImageClassificationPipeline
 from .object_detection import ObjectDetectionPipeline
 from .seal_recognition import SealRecognitionPipeline
 from .table_recognition import TableRecognitionPipeline
+from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
 from .formula_recognition import FormulaRecognitionPipeline
 from .video_classification import VideoClassificationPipeline
 from .anomaly_detection import AnomalyDetectionPipeline

+ 15 - 0
paddlex/inference/pipelines_new/multilingual_speech_recognition/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import MultilingualSpeechRecognitionPipeline

+ 71 - 0
paddlex/inference/pipelines_new/multilingual_speech_recognition/pipeline.py

@@ -0,0 +1,71 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Generator, List, Optional, Union
+import numpy as np
+
+from ...utils.pp_option import PaddlePredictorOption
+from ..base import BasePipeline
+from ...models_new.multilingual_speech_recognition.result import WhisperResult
+
+
+class MultilingualSpeechRecognitionPipeline(BasePipeline):
+    """Multilingual Speech Recognition Pipeline"""
+
+    entities = "multilingual_speech_recognition"
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initializes the class with given configurations and options.
+
+        Args:
+            config (Dict): Configuration dictionary containing model and other parameters.
+            device (str): The device to run the prediction on. Default is None.
+            pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
+            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
+        """
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        multilingual_speech_recognition_model_config = config["SubModules"][
+            "MultilingualSpeechRecognition"
+        ]
+        self.multilingual_speech_recognition_model = self.create_model(
+            multilingual_speech_recognition_model_config
+        )
+        # only support batch size 1
+        batch_size = multilingual_speech_recognition_model_config["batch_size"]
+
+    def predict(
+        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+    ) -> WhisperResult:
+        """Predicts speech recognition results for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input audio or path.
+            **kwargs: Additional keyword arguments that can be passed to the function.
+
+        Returns:
+            WhisperResult: The predicted whisper results, support str and json output.
+        """
+        yield from self.multilingual_speech_recognition_model(input)