@@ -19,6 +19,7 @@ from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
+import paddle

from ....utils.deps import function_requires_deps, is_dep_available
from ..common.tokenizer import GPTTokenizer
@@ -28,1917 +29,1899 @@ if is_dep_available("soundfile"):
if is_dep_available("tqdm"):
    import tqdm

-if is_dep_available("paddlepaddle"):
+__all__ = [
+    "Whisper",
+    "Tokenizer",
+]
+
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+
+_MODELS = ["large"]
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
+N_FRAMES = exact_div(
+    N_SAMPLES, HOP_LENGTH
+)  # 3000: number of frames in a mel spectrogram input
+
+
+@dataclass
+class ModelDimensions:
+    n_mels: int
+    n_audio_ctx: int
+    n_audio_state: int
+    n_audio_head: int
+    n_audio_layer: int
+    n_vocab: int
+    n_text_ctx: int
+    n_text_state: int
+    n_text_head: int
+    n_text_layer: int
+
+
+LANGUAGES = {
+    "en": "english",
+    "zh": "chinese",
+    "de": "german",
+    "es": "spanish",
+    "ru": "russian",
+    "ko": "korean",
+    "fr": "french",
+    "ja": "japanese",
+    "pt": "portuguese",
+    "tr": "turkish",
+    "pl": "polish",
+    "ca": "catalan",
+    "nl": "dutch",
+    "ar": "arabic",
+    "sv": "swedish",
+    "it": "italian",
+    "id": "indonesian",
+    "hi": "hindi",
+    "fi": "finnish",
+    "vi": "vietnamese",
+    "iw": "hebrew",
+    "uk": "ukrainian",
+    "el": "greek",
+    "ms": "malay",
+    "cs": "czech",
+    "ro": "romanian",
+    "da": "danish",
+    "hu": "hungarian",
+    "ta": "tamil",
+    "no": "norwegian",
+    "th": "thai",
+    "ur": "urdu",
+    "hr": "croatian",
+    "bg": "bulgarian",
+    "lt": "lithuanian",
+    "la": "latin",
+    "mi": "maori",
+    "ml": "malayalam",
+    "cy": "welsh",
+    "sk": "slovak",
+    "te": "telugu",
+    "fa": "persian",
+    "lv": "latvian",
+    "bn": "bengali",
+    "sr": "serbian",
+    "az": "azerbaijani",
+    "sl": "slovenian",
+    "kn": "kannada",
+    "et": "estonian",
+    "mk": "macedonian",
+    "br": "breton",
+    "eu": "basque",
+    "is": "icelandic",
+    "hy": "armenian",
+    "ne": "nepali",
+    "mn": "mongolian",
+    "bs": "bosnian",
+    "kk": "kazakh",
+    "sq": "albanian",
+    "sw": "swahili",
+    "gl": "galician",
+    "mr": "marathi",
+    "pa": "punjabi",
+    "si": "sinhala",
+    "km": "khmer",
+    "sn": "shona",
+    "yo": "yoruba",
+    "so": "somali",
+    "af": "afrikaans",
+    "oc": "occitan",
+    "ka": "georgian",
+    "be": "belarusian",
+    "tg": "tajik",
+    "sd": "sindhi",
+    "gu": "gujarati",
+    "am": "amharic",
+    "yi": "yiddish",
+    "lo": "lao",
+    "uz": "uzbek",
+    "fo": "faroese",
+    "ht": "haitian creole",
+    "ps": "pashto",
+    "tk": "turkmen",
+    "nn": "nynorsk",
+    "mt": "maltese",
+    "sa": "sanskrit",
+    "lb": "luxembourgish",
+    "my": "myanmar",
+    "bo": "tibetan",
+    "tl": "tagalog",
+    "mg": "malagasy",
+    "as": "assamese",
+    "tt": "tatar",
+    "haw": "hawaiian",
+    "ln": "lingala",
+    "ha": "hausa",
+    "ba": "bashkir",
+    "jw": "javanese",
+    "su": "sundanese",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+    **{language: code for code, language in LANGUAGES.items()},
+    "burmese": "my",
+    "valencian": "ca",
+    "flemish": "nl",
+    "haitian": "ht",
+    "letzeburgesch": "lb",
+    "pushto": "ps",
+    "panjabi": "pa",
+    "moldavian": "ro",
+    "moldovan": "ro",
+    "sinhalese": "si",
+    "castilian": "es",
+}
+
+
+def compression_ratio(text) -> float:
+    return len(text) / len(zlib.compress(text.encode("utf-8")))
+
+
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
+
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
+
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
+
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
+
+    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
+
+
+@dataclass(frozen=True)
+class Tokenizer:
+    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
+
+    tokenizer: "GPTTokenizer"
+    language: Optional[str]
+    sot_sequence: Tuple[int]
+
+    def encode(self, text, **kwargs):
+        return self.tokenizer.encode(text, **kwargs)

-    import paddle
-
-    __all__ = [
-        "Whisper",
-        "Tokenizer",
-    ]
-
-    def exact_div(x, y):
-        assert x % y == 0
-        return x // y
-
-    _MODELS = ["large"]
-    SAMPLE_RATE = 16000
-    N_FFT = 400
-    N_MELS = 80
-    HOP_LENGTH = 160
-    CHUNK_LENGTH = 30
-    N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-    N_FRAMES = exact_div(
-        N_SAMPLES, HOP_LENGTH
-    )  # 3000: number of frames in a mel spectrogram input
-
-    @dataclass
-    class ModelDimensions:
-        n_mels: int
-        n_audio_ctx: int
-        n_audio_state: int
-        n_audio_head: int
-        n_audio_layer: int
-        n_vocab: int
-        n_text_ctx: int
-        n_text_state: int
-        n_text_head: int
-        n_text_layer: int
-
-    LANGUAGES = {
-        "en": "english",
-        "zh": "chinese",
-        "de": "german",
-        "es": "spanish",
-        "ru": "russian",
-        "ko": "korean",
-        "fr": "french",
-        "ja": "japanese",
-        "pt": "portuguese",
-        "tr": "turkish",
-        "pl": "polish",
-        "ca": "catalan",
-        "nl": "dutch",
-        "ar": "arabic",
-        "sv": "swedish",
-        "it": "italian",
-        "id": "indonesian",
-        "hi": "hindi",
-        "fi": "finnish",
-        "vi": "vietnamese",
-        "iw": "hebrew",
-        "uk": "ukrainian",
-        "el": "greek",
-        "ms": "malay",
-        "cs": "czech",
-        "ro": "romanian",
-        "da": "danish",
-        "hu": "hungarian",
-        "ta": "tamil",
-        "no": "norwegian",
-        "th": "thai",
-        "ur": "urdu",
-        "hr": "croatian",
-        "bg": "bulgarian",
-        "lt": "lithuanian",
-        "la": "latin",
-        "mi": "maori",
-        "ml": "malayalam",
-        "cy": "welsh",
-        "sk": "slovak",
-        "te": "telugu",
-        "fa": "persian",
-        "lv": "latvian",
-        "bn": "bengali",
-        "sr": "serbian",
-        "az": "azerbaijani",
-        "sl": "slovenian",
-        "kn": "kannada",
-        "et": "estonian",
-        "mk": "macedonian",
-        "br": "breton",
-        "eu": "basque",
-        "is": "icelandic",
-        "hy": "armenian",
-        "ne": "nepali",
-        "mn": "mongolian",
-        "bs": "bosnian",
-        "kk": "kazakh",
-        "sq": "albanian",
-        "sw": "swahili",
-        "gl": "galician",
-        "mr": "marathi",
-        "pa": "punjabi",
-        "si": "sinhala",
-        "km": "khmer",
-        "sn": "shona",
-        "yo": "yoruba",
-        "so": "somali",
-        "af": "afrikaans",
-        "oc": "occitan",
-        "ka": "georgian",
-        "be": "belarusian",
-        "tg": "tajik",
-        "sd": "sindhi",
-        "gu": "gujarati",
-        "am": "amharic",
-        "yi": "yiddish",
-        "lo": "lao",
-        "uz": "uzbek",
-        "fo": "faroese",
-        "ht": "haitian creole",
-        "ps": "pashto",
-        "tk": "turkmen",
-        "nn": "nynorsk",
-        "mt": "maltese",
-        "sa": "sanskrit",
-        "lb": "luxembourgish",
-        "my": "myanmar",
-        "bo": "tibetan",
-        "tl": "tagalog",
-        "mg": "malagasy",
-        "as": "assamese",
-        "tt": "tatar",
-        "haw": "hawaiian",
-        "ln": "lingala",
-        "ha": "hausa",
-        "ba": "bashkir",
-        "jw": "javanese",
-        "su": "sundanese",
-    }
-
-    # language code lookup by name, with a few language aliases
-    TO_LANGUAGE_CODE = {
-        **{language: code for code, language in LANGUAGES.items()},
-        "burmese": "my",
-        "valencian": "ca",
-        "flemish": "nl",
-        "haitian": "ht",
-        "letzeburgesch": "lb",
-        "pushto": "ps",
-        "panjabi": "pa",
-        "moldavian": "ro",
-        "moldovan": "ro",
-        "sinhalese": "si",
-        "castilian": "es",
-    }
-
-    def compression_ratio(text) -> float:
-        return len(text) / len(zlib.compress(text.encode("utf-8")))
-
-    def format_timestamp(
-        seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+    def decode(
+        self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
    ):
-        assert seconds >= 0, "non-negative timestamp expected"
-        milliseconds = round(seconds * 1000.0)
-
-        hours = milliseconds // 3_600_000
-        milliseconds -= hours * 3_600_000
-
-        minutes = milliseconds // 60_000
-        milliseconds -= minutes * 60_000
-
-        seconds = milliseconds // 1_000
-        milliseconds -= seconds * 1_000
-
-        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
-
-    @dataclass(frozen=True)
-    class Tokenizer:
-        """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
-
-        tokenizer: "GPTTokenizer"
-        language: Optional[str]
-        sot_sequence: Tuple[int]
+        if len(token_ids) > 1:
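+            # keep only ids within the base vocabulary; higher ids (e.g. timestamp tokens) are dropped before decoding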
+            ids_list = []
+            for ids in token_ids:
+                if paddle.is_tensor(ids):
+                    ids = ids.item()
+                if ids < len(self.tokenizer):
+                    ids_list.append(ids)
+            token_ids = ids_list
+        elif len(token_ids) == 1:
+            token_ids = token_ids[0]
+        else:
+            raise ValueError(f"token_ids {token_ids} load error.")

-        def encode(self, text, **kwargs):
-            return self.tokenizer.encode(text, **kwargs)
+        return self.tokenizer.decode(token_ids, **kwargs)

-        def decode(
-            self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
-        ):
-            if len(token_ids) > 1:
-                ids_list = []
-                for ids in token_ids:
-                    if paddle.is_tensor(ids):
-                        ids = ids.item()
-                    if ids < len(self.tokenizer):
-                        ids_list.append(ids)
-                token_ids = ids_list
-            elif len(token_ids) == 1:
-                token_ids = token_ids[0]
+    def decode_with_timestamps(self, tokens) -> str:
+        """
+        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+        This method decodes given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
+        """
+        outputs = [[]]
+        for token in tokens:
+            if token >= self.timestamp_begin:
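+                # timestamp tokens are spaced 0.02 s apart (one token per two mel frames)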
+                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+                outputs.append(timestamp)
+                outputs.append([])
            else:
-                raise ValueError(f"token_ids {token_ids} load error.")
-
-            return self.tokenizer.decode(token_ids, **kwargs)
-
-        def decode_with_timestamps(self, tokens) -> str:
-            """
-            Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
-            This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
-            """
-            outputs = [[]]
-            for token in tokens:
-                if token >= self.timestamp_begin:
-                    timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
-                    outputs.append(timestamp)
-                    outputs.append([])
-                else:
-                    outputs[-1].append(token)
-            outputs = [
-                s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
-            ]
-            return "".join(outputs)
-
-        @property
-        @lru_cache()
-        def eot(self) -> int:
-            return self.tokenizer.eos_token_id
-
-        @property
-        @lru_cache()
-        def sot(self) -> int:
-            return self._get_single_token_id("<|startoftranscript|>")
-
-        @property
-        @lru_cache()
-        def sot_lm(self) -> int:
-            return self._get_single_token_id("<|startoflm|>")
-
-        @property
-        @lru_cache()
-        def sot_prev(self) -> int:
-            return self._get_single_token_id("<|startofprev|>")
-
-        @property
-        @lru_cache()
-        def no_speech(self) -> int:
-            return self._get_single_token_id("<|nospeech|>")
-
-        @property
-        @lru_cache()
-        def no_timestamps(self) -> int:
-            return self._get_single_token_id("<|notimestamps|>")
-
-        @property
-        @lru_cache()
-        def timestamp_begin(self) -> int:
-            return self.tokenizer.all_special_ids[-1] + 1
-
-        @property
-        @lru_cache()
-        def language_token(self) -> int:
-            """Returns the token id corresponding to the value of the `language` field"""
-            if self.language is None:
-                raise ValueError(
-                    "This tokenizer does not have language token configured"
-                )
-
-            additional_tokens = dict(
-                zip(
-                    self.tokenizer.additional_special_tokens,
-                    self.tokenizer.additional_special_tokens_ids,
-                )
-            )
-            candidate = f"<|{self.language}|>"
-            if candidate in additional_tokens:
-                return additional_tokens[candidate]
-
-            raise KeyError(f"Language {self.language} not found in tokenizer.")
-
-        @property
-        @lru_cache()
-        def all_language_tokens(self) -> Tuple[int]:
-            result = []
-            for token, token_id in zip(
+                outputs[-1].append(token)
+        outputs = [
+            s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
+        ]
+        return "".join(outputs)
+
+    @property
+    @lru_cache()
+    def eot(self) -> int:
+        return self.tokenizer.eos_token_id
+
+    @property
+    @lru_cache()
+    def sot(self) -> int:
+        return self._get_single_token_id("<|startoftranscript|>")
+
+    @property
+    @lru_cache()
+    def sot_lm(self) -> int:
+        return self._get_single_token_id("<|startoflm|>")
+
+    @property
+    @lru_cache()
+    def sot_prev(self) -> int:
+        return self._get_single_token_id("<|startofprev|>")
+
+    @property
+    @lru_cache()
+    def no_speech(self) -> int:
+        return self._get_single_token_id("<|nospeech|>")
+
+    @property
+    @lru_cache()
+    def no_timestamps(self) -> int:
+        return self._get_single_token_id("<|notimestamps|>")
+
+    @property
+    @lru_cache()
+    def timestamp_begin(self) -> int:
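+        # timestamp ids start right after the last special token id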
+        return self.tokenizer.all_special_ids[-1] + 1
+
+    @property
+    @lru_cache()
+    def language_token(self) -> int:
+        """Returns the token id corresponding to the value of the `language` field"""
+        if self.language is None:
+            raise ValueError("This tokenizer does not have language token configured")
+
+        additional_tokens = dict(
+            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
-            ):
-                if token.strip("<|>") in LANGUAGES:
-                    result.append(token_id)
-            return tuple(result)
-
-        @property
-        @lru_cache()
-        def all_language_codes(self) -> Tuple[str]:
-            return tuple(
-                self.decode([l]).strip("<|>") for l in self.all_language_tokens
            )
+        )
+        candidate = f"<|{self.language}|>"
+        if candidate in additional_tokens:
+            return additional_tokens[candidate]
+
+        raise KeyError(f"Language {self.language} not found in tokenizer.")
+
+    @property
+    @lru_cache()
+    def all_language_tokens(self) -> Tuple[int]:
+        result = []
+        for token, token_id in zip(
+            self.tokenizer.additional_special_tokens,
+            self.tokenizer.additional_special_tokens_ids,
+        ):
+            if token.strip("<|>") in LANGUAGES:
+                result.append(token_id)
+        return tuple(result)
+
+    @property
+    @lru_cache()
+    def all_language_codes(self) -> Tuple[str]:
+        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
+
+    @property
+    @lru_cache()
+    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+        return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+    @property
+    @lru_cache()
+    def non_speech_tokens(self) -> Tuple[int]:
+        """
+        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+        - ♪♪♪
+        - ( SPEAKING FOREIGN LANGUAGE )
+        - [DAVID] Hey there,
+        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
+        """
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )

-        @property
-        @lru_cache()
-        def sot_sequence_including_notimestamps(self) -> Tuple[int]:
-            return tuple(list(self.sot_sequence) + [self.no_timestamps])
-
-        @property
-        @lru_cache()
-        def non_speech_tokens(self) -> Tuple[int]:
-            """
-            Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
-            annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
-            - ♪♪♪
-            - ( SPEAKING FOREIGN LANGUAGE )
-            - [DAVID] Hey there,
-            keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
-            """
-            symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-            symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-
-            # symbols that may be a single token or multiple tokens depending on the tokenizer.
-            # In case they're multiple tokens, suppress the first token, which is safe because:
-            # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
-            # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
-            miscellaneous = set("♩♪♫♬♭♮♯")
-            assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
-
-            # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-            result = {
-                self.tokenizer.encode(" -").input_ids[0],
-                self.tokenizer.encode(" '").input_ids[0],
-            }
-            for symbol in symbols + list(miscellaneous):
-                for tokens in [
-                    self.tokenizer.encode(symbol).input_ids,
-                    self.tokenizer.encode(" " + symbol).input_ids,
-                ]:
-                    if len(tokens) == 1 or symbol in miscellaneous:
-                        result.add(tokens[0])
-
-            return tuple(sorted(result))
-
-        def _get_single_token_id(self, text) -> int:
-            tokens = self.tokenizer.encode(text).input_ids
-            assert len(tokens) == 1, f"{text} is not encoded as a single token"
-            return tokens[0]
-
-    @lru_cache(maxsize=None)
-    def build_tokenizer(resource_path: str, name: str = "gpt2"):
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-        path = os.path.join(resource_path, "assets", name)
-        tokenizer = GPTTokenizer.from_pretrained(path)
-
-        specials = [
-            "<|startoftranscript|>",
-            *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
-            "<|translate|>",
-            "<|transcribe|>",
-            "<|startoflm|>",
-            "<|startofprev|>",
-            "<|nospeech|>",
-            "<|notimestamps|>",
-        ]
-
-        tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-        return tokenizer
+        # symbols that may be a single token or multiple tokens depending on the tokenizer.
+        # In case they're multiple tokens, suppress the first token, which is safe because:
+        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {
+            self.tokenizer.encode(" -").input_ids[0],
+            self.tokenizer.encode(" '").input_ids[0],
+        }
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.tokenizer.encode(symbol).input_ids,
+                self.tokenizer.encode(" " + symbol).input_ids,
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
+    def _get_single_token_id(self, text) -> int:
+        tokens = self.tokenizer.encode(text).input_ids
+        assert len(tokens) == 1, f"{text} is not encoded as a single token"
+        return tokens[0]
+
+
+@lru_cache(maxsize=None)
+def build_tokenizer(resource_path: str, name: str = "gpt2"):
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    path = os.path.join(resource_path, "assets", name)
+    tokenizer = GPTTokenizer.from_pretrained(path)
+
+    specials = [
+        "<|startoftranscript|>",
+        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        "<|translate|>",
+        "<|transcribe|>",
+        "<|startoflm|>",
+        "<|startofprev|>",
+        "<|nospeech|>",
+        "<|notimestamps|>",
+    ]

-    @lru_cache(maxsize=None)
-    def get_tokenizer(
-        multilingual: bool,
-        resource_path: str,
-        *,
-        task: Optional[str] = None,  # Literal["transcribe", "translate", None]
-        language: Optional[str] = None,
-    ) -> Tokenizer:
-        if language is not None:
-            language = language.lower()
-            if language not in LANGUAGES:
-                if language in TO_LANGUAGE_CODE:
-                    language = TO_LANGUAGE_CODE[language]
-                else:
-                    raise ValueError(f"Unsupported language: {language}")
+    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+    return tokenizer
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+    multilingual: bool,
+    resource_path: str,
+    *,
+    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+    language: Optional[str] = None,
+) -> Tokenizer:
+    if language is not None:
+        language = language.lower()
+        if language not in LANGUAGES:
+            if language in TO_LANGUAGE_CODE:
+                language = TO_LANGUAGE_CODE[language]
+            else:
+                raise ValueError(f"Unsupported language: {language}")
+
+    if multilingual:
+        tokenizer_name = "multilingual"
+        task = task or "transcribe"
+        language = language or "en"
+    else:
+        tokenizer_name = "gpt2"
+        task = None
+        language = None
+
+    tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
+    all_special_ids: List[int] = tokenizer.all_special_ids
+    sot: int = all_special_ids[1]
+    translate: int = all_special_ids[-6]
+    transcribe: int = all_special_ids[-5]
+
+    langs = tuple(LANGUAGES.keys())
+    sot_sequence = [sot]
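+    # the language token id is assumed to sit right after <|startoftranscript|>, in LANGUAGES order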
+    if language is not None:
+        sot_sequence.append(sot + 1 + langs.index(language))
+    if task is not None:
+        sot_sequence.append(transcribe if task == "transcribe" else translate)
+
+    return Tokenizer(
+        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
+    )
+
+
+class MultiHeadAttention(paddle.nn.Layer):
+    def __init__(self, n_state: int, n_head: int):
+        super().__init__()
+        self.n_head = n_head
+        self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+        self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
+        self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+        self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+
+    def forward(
+        self,
+        x: paddle.Tensor,
+        xa: Optional[paddle.Tensor] = None,
+        mask: Optional[paddle.Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        q = self.query(x)

-        if multilingual:
-            tokenizer_name = "multilingual"
-            task = task or "transcribe"
-            language = language or "en"
+        if kv_cache is None or xa is None or self.key not in kv_cache:
+            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+            # otherwise, perform key/value projections for self- or cross-attention as usual.
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
        else:
-            tokenizer_name = "gpt2"
-            task = None
-            language = None
-
-        tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
-        all_special_ids: List[int] = tokenizer.all_special_ids
-        sot: int = all_special_ids[1]
-        translate: int = all_special_ids[-6]
-        transcribe: int = all_special_ids[-5]
-
-        langs = tuple(LANGUAGES.keys())
-        sot_sequence = [sot]
-        if language is not None:
-            sot_sequence.append(sot + 1 + langs.index(language))
-        if task is not None:
-            sot_sequence.append(transcribe if task == "transcribe" else translate)
-
-        return Tokenizer(
-            tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
+            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
+
+        wv = self.qkv_attention(q, k, v, mask)
+        return self.out(wv)
+
+    def qkv_attention(
+        self,
+        q: paddle.Tensor,
+        k: paddle.Tensor,
+        v: paddle.Tensor,
+        mask: Optional[paddle.Tensor] = None,
+    ):
+        n_batch, n_ctx, n_state = q.shape
+        scale = (n_state // self.n_head) ** -0.25
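+        # scale q and k by d_head**-0.25 each, so q @ k carries the standard 1/sqrt(d_head) factor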
+        q = (
+            paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
+            * scale
        )
+        k = (
+            paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
+            * scale
+        )
+        v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))

-    class MultiHeadAttention(paddle.nn.Layer):
-        def __init__(self, n_state: int, n_head: int):
-            super().__init__()
-            self.n_head = n_head
-            self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-            self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
-            self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-            self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-
-        def forward(
-            self,
-            x: paddle.Tensor,
-            xa: Optional[paddle.Tensor] = None,
-            mask: Optional[paddle.Tensor] = None,
-            kv_cache: Optional[dict] = None,
-        ):
-            q = self.query(x)
-
-            if kv_cache is None or xa is None or self.key not in kv_cache:
-                # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
-                # otherwise, perform key/value projections for self- or cross-attention as usual.
-                k = self.key(x if xa is None else xa)
-                v = self.value(x if xa is None else xa)
-            else:
-                # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-                k = kv_cache[self.key]
-                v = kv_cache[self.value]
-
-            wv = self.qkv_attention(q, k, v, mask)
-            return self.out(wv)
-
-        def qkv_attention(
-            self,
-            q: paddle.Tensor,
-            k: paddle.Tensor,
-            v: paddle.Tensor,
-            mask: Optional[paddle.Tensor] = None,
-        ):
-            n_batch, n_ctx, n_state = q.shape
-            scale = (n_state // self.n_head) ** -0.25
-            q = (
-                paddle.transpose(
-                    q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3)
-                )
-                * scale
-            )
-            k = (
-                paddle.transpose(
-                    k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1)
-                )
-                * scale
-            )
-            v = paddle.transpose(
-                v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3)
-            )
+        qk = q @ k
+        if mask is not None:
+            qk = qk + mask[:n_ctx, :n_ctx]

-            qk = q @ k
-            if mask is not None:
-                qk = qk + mask[:n_ctx, :n_ctx]
+        w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
+        return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)

-            w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
-            return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)

-    class ResidualAttentionBlock(paddle.nn.Layer):
-        def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
-            super().__init__()
+class ResidualAttentionBlock(paddle.nn.Layer):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+        super().__init__()

-            self.attn = MultiHeadAttention(n_state, n_head)
-            self.attn_ln = paddle.nn.LayerNorm(n_state)
+        self.attn = MultiHeadAttention(n_state, n_head)
+        self.attn_ln = paddle.nn.LayerNorm(n_state)

-            self.cross_attn = (
-                MultiHeadAttention(n_state, n_head) if cross_attention else None
-            )
-            self.cross_attn_ln = (
-                paddle.nn.LayerNorm(n_state) if cross_attention else None
-            )
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
+        self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None

-            n_mlp = n_state * 4
-            self.mlp = paddle.nn.Sequential(
-                paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
-                paddle.nn.GELU(),
-                paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
-            )
-            self.mlp_ln = paddle.nn.LayerNorm(n_state)
-
-        def forward(
-            self,
-            x: paddle.Tensor,
-            xa: Optional[paddle.Tensor] = None,
-            mask: Optional[paddle.Tensor] = None,
-            kv_cache: Optional[dict] = None,
-        ):
-            x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
-            if self.cross_attn:
-                x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
-            x = x + self.mlp(self.mlp_ln(x))
-            return x
-
-    def sinusoids(length, channels, max_timescale=10000):
-        """Returns sinusoids for positional embedding"""
-        assert channels % 2 == 0
-        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = paddle.exp(
-            -log_timescale_increment
-            * paddle.arange(channels // 2, dtype=paddle.float32)
+        n_mlp = n_state * 4
+        self.mlp = paddle.nn.Sequential(
+            paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
+            paddle.nn.GELU(),
+            paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
        )
-        scaled_time = (
-            paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
-            * inv_timescales[np.newaxis, :]
+        self.mlp_ln = paddle.nn.LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: paddle.Tensor,
+        xa: Optional[paddle.Tensor] = None,
+        mask: Optional[paddle.Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        if self.cross_attn:
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+        x = x + self.mlp(self.mlp_ln(x))
+        return x
+
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = paddle.exp(
+        -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
+    )
+    scaled_time = (
+        paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
+        * inv_timescales[np.newaxis, :]
+    )
+    return paddle.to_tensor(
+        paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
+    )
+
+
+class AudioEncoder(paddle.nn.Layer):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+        self.conv1 = paddle.nn.Conv1D(
+            n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
        )
-        return paddle.to_tensor(
-            paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
+        self.conv2 = paddle.nn.Conv1D(
+            n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
        )
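+        # conv2's stride of 2 halves the time axis, matching the length of the positional embedding (n_ctx)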
+        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

-    class AudioEncoder(paddle.nn.Layer):
-        def __init__(
-            self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-        ):
-            super().__init__()
-            self.conv1 = paddle.nn.Conv1D(
-                n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
-            )
-            self.conv2 = paddle.nn.Conv1D(
-                n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
-            )
-            self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
-
-            self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
-                [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-            )
-            self.ln_post = paddle.nn.LayerNorm(n_state)
-
-        def forward(self, x: paddle.Tensor):
-            """
-            x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
-                the mel spectrogram of the audio
-            """
-            x = paddle.nn.functional.gelu(self.conv1(x))
-            x = paddle.nn.functional.gelu(self.conv2(x))
-            x = paddle.transpose(x, (0, 2, 1))
-
-            assert (
-                x.shape[1:] == self.positional_embedding.shape
-            ), "incorrect audio shape"
-            x = x + self.positional_embedding
-
-            for block in self.blocks:
-                x = block(x)
-
-            x = self.ln_post(x)
-            return x
-
-    class TextDecoder(paddle.nn.Layer):
-        def __init__(
-            self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-        ):
-            super().__init__()
-
-            self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
-            self.positional_embedding = paddle.create_parameter(
-                shape=[n_ctx, n_state], dtype="float32"
-            )
+        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
+            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+        )
+        self.ln_post = paddle.nn.LayerNorm(n_state)

-            self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
-                [
-                    ResidualAttentionBlock(n_state, n_head, cross_attention=True)
-                    for _ in range(n_layer)
-                ]
-            )
-            self.ln = paddle.nn.LayerNorm(n_state)
+    def forward(self, x: paddle.Tensor):
+        """
+        x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
+            the mel spectrogram of the audio
+        """
+        x = paddle.nn.functional.gelu(self.conv1(x))
+        x = paddle.nn.functional.gelu(self.conv2(x))
+        x = paddle.transpose(x, (0, 2, 1))

-            mask = paddle.full(
-                shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32"
-            )
-            mask = paddle.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask, persistable=False)
+        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+        x = x + self.positional_embedding

-        def forward(
-            self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
-        ):
-            """
-            x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
-                the text tokens
-            xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
-                the encoded audio features to be attended on
-            """
-            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-            x = (
-                self.token_embedding(x)
-                + self.positional_embedding[offset : offset + x.shape[-1]]
-            )
-            x = x.to(xa.dtype)
+        for block in self.blocks:
+            x = block(x)

-            for block in self.blocks:
-                x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+        x = self.ln_post(x)
+        return x

-            x = self.ln(x)
-            logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))

-            return logits
+class TextDecoder(paddle.nn.Layer):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()

-    @dataclass(frozen=True)
-    class DecodingOptions:
-        task: str = (
-            "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
+        self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
+        self.positional_embedding = paddle.create_parameter(
+            shape=[n_ctx, n_state], dtype="float32"
        )
-        language: Optional[str] = (
-            None  # language that the audio is in; uses detected language if None
+
+        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
        )
-        # sampling-related options
-        temperature: float = 0.0
-        sample_len: Optional[int] = None  # maximum number of tokens to sample
-        best_of: Optional[int] = (
-            None  # number of independent samples to collect, when t > 0
+        self.ln = paddle.nn.LayerNorm(n_state)
+
+        mask = paddle.full(shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32")
+        mask = paddle.triu(mask, diagonal=1)
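+        # additive upper-triangular -inf mask enforces causal self-attention; qkv_attention slices it to [:n_ctx, :n_ctx]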
+        self.register_buffer("mask", mask, persistable=False)
+
+    def forward(
+        self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
+    ):
+        """
+        x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
+            the text tokens
+        xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
+            the encoded audio features to be attended on
+        """
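+        # with a kv cache, only the new tokens are passed in, so the positional embedding is offset by the cached length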
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
-        beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
-        patience: Optional[float] = (
-            None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+        x = x.to(xa.dtype)
+
+        for block in self.blocks:
+            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+        x = self.ln(x)
+        logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
+
+        return logits
+
+
+@dataclass(frozen=True)
+class DecodingOptions:
+    task: str = (
+        "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
+    )
+    language: Optional[str] = (
+        None  # language that the audio is in; uses detected language if None
+    )
+    # sampling-related options
+    temperature: float = 0.0
+    sample_len: Optional[int] = None  # maximum number of tokens to sample
+    best_of: Optional[int] = (
+        None  # number of independent samples to collect, when t > 0
+    )
+    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
+    patience: Optional[float] = (
+        None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+    )
+
+    # options for ranking generations (either beams or best-of-N samples)
+    length_penalty: Optional[float] = (
+        None  # "alpha" in Google NMT, None defaults to length norm
+    )
+
+    # prompt, prefix, and token suppression
+    prompt: Optional[Union[str, List[int]]] = (
+        None  # text or tokens for the previous context
+    )
+    prefix: Optional[Union[str, List[int]]] = (
+        None  # text or tokens to prefix the current context
+    )
+    suppress_blank: bool = True  # this will suppress blank outputs
+
+    # list of tokens ids (or comma-separated token ids) to suppress
+    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
+    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+
+    # timestamp sampling options
+    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
+    max_initial_timestamp: Optional[float] = (
+        1.0  # the initial timestamp cannot be later than this
+    )
+
+    # implementation details
+    fp16: bool = False  # use fp16 for most of the calculation
+
+
+@dataclass(frozen=True)
+class DecodingResult:
+    audio_features: paddle.Tensor
+    language: str
+    language_probs: Optional[Dict[str, float]] = None
+    tokens: List[int] = field(default_factory=list)
+    text: str = ""
+    avg_logprob: float = np.nan
+    no_speech_prob: float = np.nan
+    temperature: float = np.nan
+    compression_ratio: float = np.nan
+
+
+class Inference:
+    def logits(
+        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
+    ) -> paddle.Tensor:
+        """Perform a forward pass on the decoder and return per-token logits"""
+        raise NotImplementedError
+
+    def rearrange_kv_cache(self, source_indices) -> None:
+        """Update the key-value cache according to the updated beams"""
+        raise NotImplementedError
+
+    def cleanup_caching(self) -> None:
+        """Clean up any resources or hooks after decoding is finished"""
+
+
+class WhisperInference(Inference):
+    def __init__(self, model: "Whisper", initial_token_length: int):
+        self.model: "Whisper" = model
+        self.initial_token_length = initial_token_length
+        self.kv_cache = {}
+        self.hooks = []
+
+    def logits(
+        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
+    ) -> paddle.Tensor:
+        if not self.kv_cache:
+            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+
+        if tokens.shape[-1] > self.initial_token_length:
+            # only need to use the last token except in the first forward pass
+            tokens = tokens[:, -1:]
+
+        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
+
+    def cleanup_caching(self):
+        for hook in self.hooks:
+            hook.remove()
+
+        self.kv_cache = {}
+        self.hooks = []
+
+    def rearrange_kv_cache(self, source_indices):
+        for module, tensor in self.kv_cache.items():
+            # update the key/value cache to contain the selected sequences
+            self.kv_cache[module] = tensor[source_indices].detach()
+
+
+@paddle.no_grad()
+def detect_language(
+    model: "Whisper",
+    mel: paddle.Tensor,
+    resource_path: str,
+    tokenizer: Tokenizer = None,
+) -> Tuple[paddle.Tensor, List[dict]]:
+    """
+    Detect the spoken language in the audio, and return them as a list of strings, along with the ids
+    of the most probable language tokens and the probability distribution over all language tokens.
+    This is performed outside the main decode loop in order to not interfere with kv-caching.
+    Returns
+    -------
+    language_tokens : Tensor, shape = (batch_size,)
+        ids of the most probable language tokens, which appear after the startoftranscript token.
+    language_probs : List[Dict[str, float]], length = batch_size
+        list of dictionaries containing the probability distribution over all languages.
+    """
+    if tokenizer is None:
+        tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
        )

-        # options for ranking generations (either beams or best-of-N samples)
-        length_penalty: Optional[float] = (
-            None  # "alpha" in Google NMT, None defaults to length norm
-        )
+    single = mel.ndim == 2
+    if single:
+        mel = mel.unsqueeze(0)
+
+    # skip encoder forward pass if already-encoded audio features were given
+    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+        mel = model.encoder(mel)
+
+    # forward pass using a single token, startoftranscript
+    batch_size = mel.shape[0]
+    x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
+    logits = model.logits(x, mel)[:, 0]
+
+    # collect detected languages; suppress all non-language tokens
+    mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
+    mask[list(tokenizer.all_language_tokens)] = False
+    logits[:, mask] = -np.inf
+    language_tokens = paddle.argmax(logits, axis=-1)
+    language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
+    language_probs = [
+        {
+            c: language_token_probs[i, j].tolist()
+            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+        }
+        for i in range(batch_size)
+    ]

-        # prompt, prefix, and token suppression
-        prompt: Optional[Union[str, List[int]]] = (
-            None  # text or tokens for the previous context
-        )
-        prefix: Optional[Union[str, List[int]]] = (
-            None  # text or tokens to prefix the current context
+    if single:
+        language_tokens = language_tokens[0]
+        language_probs = language_probs[0]
+
+    return language_tokens, language_probs
+
+
+@function_requires_deps("tqdm")
+def transcribe(
+    model: "Whisper",
+    mel: paddle.Tensor,
+    resource_path: str,
+    *,
+    verbose: Optional[bool] = None,
+    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+    compression_ratio_threshold: Optional[float] = 2.4,
+    logprob_threshold: Optional[float] = -1.0,
+    no_speech_threshold: Optional[float] = 0.6,
+    condition_on_previous_text: bool = True,
+    **decode_options,
+):
+    """
+    Transcribe an audio file using Whisper
+    Parameters
+    ----------
+    model: Whisper
+        The Whisper model instance
+    mel: paddle.Tensor
+        The audio feature
+    verbose: bool
+        Whether to display the text being decoded to the console. If True, displays all the details.
+        If False, displays minimal details. If None, does not display anything
+    temperature: Union[float, Tuple[float, ...]]
+        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
+        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
+    compression_ratio_threshold: float
+        If the gzip compression ratio is above this value, treat as failed
+    logprob_threshold: float
+        If the average log probability over sampled tokens is below this value, treat as failed
+    no_speech_threshold: float
+        If the no_speech probability is higher than this value AND the average log probability
+        over sampled tokens is below `logprob_threshold`, consider the segment as silent
+    condition_on_previous_text: bool
+        if True, the previous output of the model is provided as a prompt for the next window;
+        disabling may make the text inconsistent across windows, but the model becomes less prone to
+        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
+    decode_options: dict
+        Keyword arguments to construct `DecodingOptions` instances
+    Returns
+    -------
+    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+    """
+    dtype = np.float32  # paddle only supports float32
+
+    if dtype == np.float32:
+        decode_options["fp16"] = False
+
+    if (
+        decode_options.get("language") == "None"
+        or decode_options.get("language", None) is None
+    ):
+        if not model.is_multilingual:
+            decode_options["language"] = "en"
+        else:
+            if verbose:
+                print(
+                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                )
+            segment = pad_or_trim(mel, N_FRAMES)
+            _, probs = model.detect_language(segment, resource_path)
+            decode_options["language"] = max(probs, key=probs.get)
+            if verbose is not None:
+                print(
+                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                )
+
+    language = decode_options["language"]
+    task = decode_options.get("task", "transcribe")
+    tokenizer = get_tokenizer(
+        model.is_multilingual,
+        resource_path=resource_path,
+        language=language,
+        task=task,
+    )
+
+    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
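+        # try each temperature in turn, falling back to the next one when a quality check fails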
+        temperatures = (
+            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
-        suppress_blank: bool = True  # this will suppress blank outputs
+        decode_result = None
+
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)

-        # list of tokens ids (or comma-separated token ids) to suppress
-        # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
-        suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options, resource_path)

-        # timestamp sampling options
-        without_timestamps: bool = (
-            False  # use <|notimestamps|> to sample text tokens only
+            needs_fallback = False
+            if (
+                compression_ratio_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                logprob_threshold is not None
+                and decode_result.avg_logprob < logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result
+
+    seek = 0
+    input_stride = exact_div(
+        N_FRAMES, model.dims.n_audio_ctx
+    )  # mel frames per output token: 2
+    time_precision = (
+        input_stride * HOP_LENGTH / SAMPLE_RATE
+    )  # time per output token: 0.02 (seconds)
+    all_tokens = []
+    all_segments = []
+    prompt_reset_since = 0
+
+    initial_prompt = decode_options.pop("initial_prompt", None)
+    if initial_prompt and initial_prompt != "None":
+        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
+        all_tokens.extend(initial_prompt)
+    else:
+        initial_prompt = []
+
+    def add_segment(
+        *,
+        start: float,
+        end: float,
+        text_tokens: paddle.Tensor,
+        result: DecodingResult,
+    ):
+        text = tokenizer.decode(
+            [token for token in text_tokens if token < tokenizer.eot]
        )
-        max_initial_timestamp: Optional[float] = (
-            1.0  # the initial timestamp cannot be later than this
+        if len(text.strip()) == 0:  # skip empty text output
+            return
+
+        all_segments.append(
+            {
+                "id": len(all_segments),
+                "seek": seek,
+                "start": start,
+                "end": end,
+                "text": text,
+                "tokens": result.tokens,
+                "temperature": result.temperature,
+                "avg_logprob": result.avg_logprob,
+                "compression_ratio": result.compression_ratio,
+                "no_speech_prob": result.no_speech_prob,
+            }
        )
+        if verbose:
+            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
+
+    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    num_frames = mel.shape[-1]
+    previous_seek_value = seek
+
+    with tqdm.tqdm(
+        total=num_frames, unit="frames", disable=verbose is not False
+    ) as pbar:
+        while seek < num_frames:
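+            # seek counts mel frames; HOP_LENGTH / SAMPLE_RATE converts a frame index to seconds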
+            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
+            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            result: DecodingResult = decode_with_fallback(segment)
+            tokens = paddle.to_tensor(result.tokens)
+
+            if no_speech_threshold is not None:
+                # no voice activity check
+                should_skip = result.no_speech_prob > no_speech_threshold
+                if (
+                    logprob_threshold is not None
+                    and result.avg_logprob > logprob_threshold
+                ):
+                    # don't skip if the logprob is high enough, despite the no_speech_prob
+                    should_skip = False

-        # implementation details
-        fp16: bool = False  # use fp16 for most of the calculation
-
-    @dataclass(frozen=True)
-    class DecodingResult:
-        audio_features: paddle.Tensor
-        language: str
-        language_probs: Optional[Dict[str, float]] = None
-        tokens: List[int] = field(default_factory=list)
-        text: str = ""
-        avg_logprob: float = np.nan
-        no_speech_prob: float = np.nan
-        temperature: float = np.nan
-        compression_ratio: float = np.nan
-
-    class Inference:
-        def logits(
-            self, tokens: paddle.Tensor, audio_features: paddle.Tensor
-        ) -> paddle.Tensor:
-            """Perform a forward pass on the decoder and return per-token logits"""
-            raise NotImplementedError
-
-        def rearrange_kv_cache(self, source_indices) -> None:
-            """Update the key-value cache according to the updated beams"""
-            raise NotImplementedError
-
-        def cleanup_caching(self) -> None:
-            """Clean up any resources or hooks after decoding is finished"""
-
-    class WhisperInference(Inference):
-        def __init__(self, model: "Whisper", initial_token_length: int):
-            self.model: "Whisper" = model
-            self.initial_token_length = initial_token_length
-            self.kv_cache = {}
-            self.hooks = []
-
-        def logits(
-            self, tokens: paddle.Tensor, audio_features: paddle.Tensor
-        ) -> paddle.Tensor:
-            if not self.kv_cache:
-                self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
-
-            if tokens.shape[-1] > self.initial_token_length:
-                # only need to use the last token except in the first forward pass
-                tokens = tokens[:, -1:]
-
-            return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
-
-        def cleanup_caching(self):
-            for hook in self.hooks:
-                hook.remove()
-
-            self.kv_cache = {}
-            self.hooks = []
-
-        def rearrange_kv_cache(self, source_indices):
-            for module, tensor in self.kv_cache.items():
-                # update the key/value cache to contain the selected sequences
-                self.kv_cache[module] = tensor[source_indices].detach()
+                if should_skip:
+                    seek += segment.shape[
+                        -1
+                    ]  # fast-forward to the next segment boundary
+                    continue

-    @paddle.no_grad()
-    def detect_language(
-        model: "Whisper",
-        mel: paddle.Tensor,
-        resource_path: str,
-        tokenizer: Tokenizer = None,
-    ) -> Tuple[paddle.Tensor, List[dict]]:
-        """
-        Detect the spoken language in the audio, and return them as list of strings, along with the ids
-        of the most probable language tokens and the probability distribution over all language tokens.
-        This is performed outside the main decode loop in order to not interfere with kv-caching.
-        Returns
-        -------
-        language_tokens : Tensor, shape = (batch_size,)
-            ids of the most probable language tokens, which appears after the startoftranscript token.
-        language_probs : List[Dict[str, float]], length = batch_size
-            list of dictionaries containing the probability distribution over all languages.
-        """
-        if tokenizer is None:
-            tokenizer = get_tokenizer(
-                model.is_multilingual, resource_path=resource_path
-            )
-        if (
-            tokenizer.language is None
-            or tokenizer.language_token not in tokenizer.sot_sequence
-        ):
-            raise ValueError(
-                "This model doesn't have language tokens so it can't perform lang id"
+                timestamp_tokens: paddle.Tensor = tokens.greater_equal(
+                    paddle.to_tensor(tokenizer.timestamp_begin)
            )

-        single = mel.ndim == 2
-        if single:
-            mel = mel.unsqueeze(0)
-
-        # skip encoder forward pass if already-encoded audio features were given
-        if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
-            mel = model.encoder(mel)
-
-        # forward pass using a single token, startoftranscript
-        batch_size = mel.shape[0]
-        x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
-        logits = model.logits(x, mel)[:, 0]
-
-        # collect detected languages; suppress all non-language tokens
-        mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
-        mask[list(tokenizer.all_language_tokens)] = False
-        logits[:, mask] = -np.inf
-        language_tokens = paddle.argmax(logits, axis=-1)
-        language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
-        language_probs = [
-            {
-                c: language_token_probs[i, j].tolist()
-                for j, c in zip(
-                    tokenizer.all_language_tokens, tokenizer.all_language_codes
+                consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
|
|
+ if (
|
|
|
+ len(consecutive) > 0
|
|
|
+ ): # if the output contains two consecutive timestamp tokens
|
|
|
+ consecutive = paddle.add(consecutive, paddle.to_tensor(1))
|
|
|
+ last_slice = 0
|
|
|
+ for current_slice in consecutive:
|
|
|
+ sliced_tokens = tokens[last_slice:current_slice]
|
|
|
+ start_timestamp_position = (
|
|
|
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
|
|
|
+ )
|
|
|
+ end_timestamp_position = (
|
|
|
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
|
|
+ )
|
|
|
+ add_segment(
|
|
|
+ start=timestamp_offset
|
|
|
+ + start_timestamp_position * time_precision,
|
|
|
+ end=timestamp_offset + end_timestamp_position * time_precision,
|
|
|
+ text_tokens=sliced_tokens[1:-1],
|
|
|
+ result=result,
|
|
|
+ )
|
|
|
+ last_slice = current_slice
|
|
|
+ last_timestamp_position = (
|
|
|
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
|
|
)
|
|
|
- }
|
|
|
- for i in range(batch_size)
|
|
|
- ]
|
|
|
+ seek += last_timestamp_position * input_stride
|
|
|
+ all_tokens.extend(tokens[: last_slice + 1].tolist())
|
|
|
+ else:
|
|
|
+ duration = segment_duration
|
|
|
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
|
+ if (
|
|
|
+ len(timestamps) > 0
|
|
|
+ and timestamps[-1].item() != tokenizer.timestamp_begin
|
|
|
+ ):
|
|
|
+ # no consecutive timestamps but it has a timestamp; use the last one.
|
|
|
+ # single timestamp at the end means no speech after the last timestamp.
|
|
|
+ last_timestamp_position = (
|
|
|
+ timestamps[-1].item() - tokenizer.timestamp_begin
|
|
|
+ )
|
|
|
+ duration = last_timestamp_position * time_precision
|
|
|
|
|
|
- if single:
|
|
|
- language_tokens = language_tokens[0]
|
|
|
- language_probs = language_probs[0]
|
|
|
+ add_segment(
|
|
|
+ start=timestamp_offset,
|
|
|
+ end=timestamp_offset + duration,
|
|
|
+ text_tokens=tokens,
|
|
|
+ result=result,
|
|
|
+ )
|
|
|
|
|
|
- return language_tokens, language_probs
|
|
|
+ seek += segment.shape[-1]
|
|
|
+ all_tokens.extend(tokens.tolist())
|
|
|
|
|
|
- @function_requires_deps("tqdm")
|
|
|
- def transcribe(
|
|
|
- model: "Whisper",
|
|
|
- mel: paddle.Tensor,
|
|
|
- resource_path: str,
|
|
|
- *,
|
|
|
- verbose: Optional[bool] = None,
|
|
|
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
|
|
- compression_ratio_threshold: Optional[float] = 2.4,
|
|
|
- logprob_threshold: Optional[float] = -1.0,
|
|
|
- no_speech_threshold: Optional[float] = 0.6,
|
|
|
- condition_on_previous_text: bool = True,
|
|
|
- **decode_options,
|
|
|
- ):
|
|
|
+ if not condition_on_previous_text or result.temperature > 0.5:
|
|
|
+ # do not feed the prompt tokens if a high temperature was used
|
|
|
+ prompt_reset_since = len(all_tokens)
|
|
|
+
|
|
|
+ # update progress bar
|
|
|
+ pbar.update(min(num_frames, seek) - previous_seek_value)
|
|
|
+ previous_seek_value = seek
|
|
|
+
|
|
|
+ return dict(
|
|
|
+ text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
|
|
|
+ segments=all_segments,
|
|
|
+ language=language,
|
|
|
+ )
|
|
|
+
|
|
|
+
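For orientation, a minimal sketch of the frame-to-seconds arithmetic driving the seek loop above, in plain Python. The constants match this module; n_audio_ctx = 1500 is an assumption (the audio-context dimension of the released large model), and the seek value is hypothetical.

SAMPLE_RATE, HOP_LENGTH, N_FRAMES = 16000, 160, 3000
n_audio_ctx = 1500                                        # assumed model dimension
input_stride = N_FRAMES // n_audio_ctx                    # 2 mel frames per output token
time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE  # 0.02 s per token

seek = 4500                                               # hypothetical mel-frame offset
timestamp_offset = seek * HOP_LENGTH / SAMPLE_RATE
print(input_stride, time_precision, timestamp_offset)     # 2 0.02 45.0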
|
|
|
+class SequenceRanker:
|
|
|
+ def rank(
|
|
|
+ self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
|
|
|
+ ) -> List[int]:
|
|
|
"""
|
|
|
- Transcribe an audio file using Whisper
|
|
|
+ Given a list of groups of samples and their cumulative log probabilities,
|
|
|
+ return the indices of the samples in each group to select as the final result
|
|
|
+ """
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
+class MaximumLikelihoodRanker(SequenceRanker):
|
|
|
+ """
|
|
|
+ Select the sample with the highest log probabilities, penalized using either
|
|
|
+ a simple length normalization or the Google NMT paper's length penalty
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, length_penalty: Optional[float]):
|
|
|
+ self.length_penalty = length_penalty
|
|
|
+
|
|
|
+ def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
|
|
|
+ def scores(logprobs, lengths):
|
|
|
+ result = []
|
|
|
+ for logprob, length in zip(logprobs, lengths):
|
|
|
+ if self.length_penalty is None or self.length_penalty == "None":
|
|
|
+ penalty = length
|
|
|
+ else:
|
|
|
+ # from the Google NMT paper
|
|
|
+ penalty = ((5 + length) / 6) ** self.length_penalty
|
|
|
+ result.append(logprob / penalty)
|
|
|
+ return result
|
|
|
+
|
|
|
+ # get the sequence with the highest score
|
|
|
+ lengths = [[len(t) for t in s] for s in tokens]
|
|
|
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
|
|
+
|
|
|
+
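For illustration, a framework-free sketch of the scoring rule above; the candidate log probabilities and lengths are fabricated.

import numpy as np

def nmt_length_penalty(length, alpha=0.6):
    # length penalty from the Google NMT paper, as applied above
    return ((5 + length) / 6) ** alpha

candidates = [(-4.2, 12), (-3.1, 6)]  # (sum_logprob, token_count) per sample
scores = [lp / nmt_length_penalty(n) for lp, n in candidates]
print(int(np.argmax(scores)))         # index of the preferred candidate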
|
|
|
+class TokenDecoder:
|
|
|
+ def reset(self):
|
|
|
+ """Initialize any stateful variables for decoding a new sequence"""
|
|
|
+
|
|
|
+ def update(
|
|
|
+ self,
|
|
|
+ tokens: paddle.Tensor,
|
|
|
+ logits: paddle.Tensor,
|
|
|
+ sum_logprobs: paddle.Tensor,
|
|
|
+ ) -> Tuple[paddle.Tensor, bool]:
|
|
|
+ """Specify how to select the next token, based on the current trace and logits
|
|
|
Parameters
|
|
|
----------
|
|
|
- model: Whisper
|
|
|
- The Whisper model instance
|
|
|
- mel: paddle.Tensor
|
|
|
- The audio feature
|
|
|
- verbose: bool
|
|
|
- Whether to display the text being decoded to the console. If True, displays all the details,
|
|
|
- If False, displays minimal details. If None, does not display anything
|
|
|
- temperature: Union[float, Tuple[float, ...]]
|
|
|
- Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
|
|
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
|
|
- compression_ratio_threshold: float
|
|
|
- If the gzip compression ratio is above this value, treat as failed
|
|
|
- logprob_threshold: float
|
|
|
- If the average log probability over sampled tokens is below this value, treat as failed
|
|
|
- no_speech_threshold: float
|
|
|
- If the no_speech probability is higher than this value AND the average log probability
|
|
|
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
|
|
- condition_on_previous_text: bool
|
|
|
- if True, the previous output of the model is provided as a prompt for the next window;
|
|
|
- disabling may make the text inconsistent across windows, but the model becomes less prone to
|
|
|
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
|
|
- decode_options: dict
|
|
|
- Keyword arguments to construct `DecodingOptions` instances
|
|
|
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
|
+ all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
+ logits : Tensor, shape = (n_batch, vocab_size)
|
|
|
+ per-token logits of the probability distribution at the current step
|
|
|
+ sum_logprobs : Tensor, shape = (n_batch)
|
|
|
+ cumulative log probabilities for each sequence
|
|
|
Returns
|
|
|
-------
|
|
|
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
|
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
|
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
|
|
+ the tokens, appended with the selected next token
|
|
|
+ completed : bool
|
|
|
+ True if all sequences have reached the end of text
|
|
|
"""
|
|
|
- dtype = np.float32 # paddle only support float32
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
- if dtype == np.float32:
|
|
|
- decode_options["fp16"] = False
|
|
|
+ def finalize(
|
|
|
+ self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
|
|
|
+ ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
|
|
|
+ """Finalize search and return the final candidate sequences
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
|
|
|
+ all tokens in the context so far, including the prefix and sot_sequence
|
|
|
+ sum_logprobs : Tensor, shape = (batch_size, beam_size)
|
|
|
+ cumulative log probabilities for each sequence
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ tokens : Sequence[Sequence[Tensor]], length = batch_size
|
|
|
+ sequence of Tensors containing candidate token sequences, for each audio input
|
|
|
+ sum_logprobs : List[List[float]], length = batch_size
|
|
|
+ sequence of cumulative log probabilities corresponding to the above
|
|
|
+ """
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
+class GreedyDecoder(TokenDecoder):
|
|
|
+ def __init__(self, temperature: float, eot: int):
|
|
|
+ self.temperature = temperature
|
|
|
+ self.eot = eot
|
|
|
+
|
|
|
+ def update(
|
|
|
+ self,
|
|
|
+ tokens: paddle.Tensor,
|
|
|
+ logits: paddle.Tensor,
|
|
|
+ sum_logprobs: paddle.Tensor,
|
|
|
+ ) -> Tuple[paddle.Tensor, bool]:
|
|
|
+ temperature = self.temperature
|
|
|
+ if temperature == 0:
|
|
|
+ next_tokens = paddle.argmax(logits, axis=-1)
|
|
|
+ else:
|
|
|
+ next_tokens = paddle.distribution.Categorical(
|
|
|
+ logits=logits / temperature
|
|
|
+ ).sample([1])
|
|
|
+ next_tokens = paddle.reshape(
|
|
|
+ next_tokens,
|
|
|
+ [
|
|
|
+ next_tokens.shape[0] * next_tokens.shape[1],
|
|
|
+ ],
|
|
|
+ )
|
|
|
|
|
|
- if (
|
|
|
- decode_options.get("language") == "None"
|
|
|
- or decode_options.get("language", None) is None
|
|
|
- ):
|
|
|
- if not model.is_multilingual:
|
|
|
- decode_options["language"] = "en"
|
|
|
- else:
|
|
|
- if verbose:
|
|
|
- print(
|
|
|
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
|
|
- )
|
|
|
- segment = pad_or_trim(mel, N_FRAMES)
|
|
|
- _, probs = model.detect_language(segment, resource_path)
|
|
|
- decode_options["language"] = max(probs, key=probs.get)
|
|
|
- if verbose is not None:
|
|
|
- print(
|
|
|
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
|
|
- )
|
|
|
+ logprobs = paddle.nn.functional.log_softmax(
|
|
|
+ logits, axis=-1, dtype=paddle.float32
|
|
|
+ )
|
|
|
+ current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
|
|
|
+ sum_logprobs += current_logprobs * paddle.to_tensor(
|
|
|
+ (tokens[:, -1] != self.eot), dtype=paddle.float32
|
|
|
+ )
|
|
|
|
|
|
- language = decode_options["language"]
|
|
|
- task = decode_options.get("task", "transcribe")
|
|
|
- tokenizer = get_tokenizer(
|
|
|
- model.is_multilingual,
|
|
|
- resource_path=resource_path,
|
|
|
- language=language,
|
|
|
- task=task,
|
|
|
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
|
|
|
+ tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
|
|
|
+
|
|
|
+ completed = paddle.all((tokens[:, -1] == self.eot))
|
|
|
+ return tokens, completed
|
|
|
+
|
|
|
+ def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
|
|
|
+ # make sure each sequence has at least one EOT token at the end
|
|
|
+ tokens = paddle.nn.functional.pad(
|
|
|
+ tokens, (0, 1), value=self.eot, data_format="NCL"
|
|
|
)
|
|
|
+ return tokens, sum_logprobs.tolist()
|
|
|
|
|
|
- def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
|
|
|
- temperatures = (
|
|
|
- [temperature] if isinstance(temperature, (int, float)) else temperature
|
|
|
- )
|
|
|
- decode_result = None
|
|
|
-
|
|
|
- for t in temperatures:
|
|
|
- kwargs = {**decode_options}
|
|
|
- if t > 0:
|
|
|
- # disable beam_size and patience when t > 0
|
|
|
- kwargs.pop("beam_size", None)
|
|
|
- kwargs.pop("patience", None)
|
|
|
+
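As a reference for GreedyDecoder.update above, the same next-token selection written with numpy only; the logits here are random placeholders, not model output.

import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 8))    # placeholder (n_batch, vocab_size) logits
temperature = 0.7

if temperature == 0:
    next_tokens = logits.argmax(axis=-1)             # greedy pick
else:
    probs = np.exp(logits / temperature)             # categorical sampling
    probs /= probs.sum(axis=-1, keepdims=True)
    next_tokens = np.array([rng.choice(len(p), p=p) for p in probs])
print(next_tokens)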
|
|
|
+class BeamSearchDecoder(TokenDecoder):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ beam_size: int,
|
|
|
+ eot: int,
|
|
|
+ inference: Inference,
|
|
|
+ patience: Optional[float] = None,
|
|
|
+ ):
|
|
|
+ self.beam_size = beam_size
|
|
|
+ self.eot = eot
|
|
|
+ self.inference = inference
|
|
|
+ if patience is None or patience == "None":
|
|
|
+ self.patience = 1.0
|
|
|
+ else:
|
|
|
+ self.patience = patience
|
|
|
+ self.max_candidates: int = round(beam_size * self.patience)
|
|
|
+ self.finished_sequences = None
|
|
|
+
|
|
|
+ assert (
|
|
|
+ self.max_candidates > 0
|
|
|
+ ), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
|
|
+
|
|
|
+ def reset(self):
|
|
|
+ self.finished_sequences = None
|
|
|
+
|
|
|
+ def update(
|
|
|
+ self,
|
|
|
+ tokens: paddle.Tensor,
|
|
|
+ logits: paddle.Tensor,
|
|
|
+ sum_logprobs: paddle.Tensor,
|
|
|
+ ) -> Tuple[paddle.Tensor, bool]:
|
|
|
+ if tokens.shape[0] % self.beam_size != 0:
|
|
|
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
|
|
+
|
|
|
+ batch_size = tokens.shape[0] // self.beam_size
|
|
|
+ if self.finished_sequences is None: # for the first update
|
|
|
+ self.finished_sequences = [{} for _ in range(batch_size)]
|
|
|
+
|
|
|
+ logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
|
|
|
+ next_tokens, source_indices, finished_sequences = [], [], []
|
|
|
+ for i in range(batch_size):
|
|
|
+ scores, sources, finished = {}, {}, {}
|
|
|
+
|
|
|
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
|
|
|
+ for j in range(self.beam_size):
|
|
|
+ idx = i * self.beam_size + j
|
|
|
+ prefix = tokens[idx].tolist()
|
|
|
+ logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
|
|
|
+ for logprob, token in zip(logprob, token):
|
|
|
+ new_logprob = (sum_logprobs[idx] + logprob).item()
|
|
|
+ sequence = tuple(prefix + [token.item()])
|
|
|
+ scores[sequence] = new_logprob
|
|
|
+ sources[sequence] = idx
|
|
|
+
|
|
|
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
|
|
+ saved = 0
|
|
|
+ for sequence in sorted(scores, key=scores.get, reverse=True):
|
|
|
+ if sequence[-1] == self.eot:
|
|
|
+ finished[sequence] = scores[sequence]
|
|
|
else:
|
|
|
- # disable best_of when t == 0
|
|
|
- kwargs.pop("best_of", None)
|
|
|
+ sum_logprobs[len(next_tokens)] = scores[sequence]
|
|
|
+ next_tokens.append(sequence)
|
|
|
+ source_indices.append(sources[sequence])
|
|
|
|
|
|
- options = DecodingOptions(**kwargs, temperature=t)
|
|
|
- decode_result = model.decode(segment, options, resource_path)
|
|
|
+ saved += 1
|
|
|
+ if saved == self.beam_size:
|
|
|
+ break
|
|
|
|
|
|
- needs_fallback = False
|
|
|
- if (
|
|
|
- compression_ratio_threshold is not None
|
|
|
- and decode_result.compression_ratio > compression_ratio_threshold
|
|
|
- ):
|
|
|
- needs_fallback = True # too repetitive
|
|
|
- if (
|
|
|
- logprob_threshold is not None
|
|
|
- and decode_result.avg_logprob < logprob_threshold
|
|
|
- ):
|
|
|
- needs_fallback = True # average log probability is too low
|
|
|
+ finished_sequences.append(finished)
|
|
|
|
|
|
- if not needs_fallback:
|
|
|
- break
|
|
|
+ tokens = paddle.to_tensor(next_tokens)
|
|
|
+ self.inference.rearrange_kv_cache(source_indices)
|
|
|
|
|
|
- return decode_result
|
|
|
-
|
|
|
- seek = 0
|
|
|
- input_stride = exact_div(
|
|
|
- N_FRAMES, model.dims.n_audio_ctx
|
|
|
- ) # mel frames per output token: 2
|
|
|
- time_precision = (
|
|
|
- input_stride * HOP_LENGTH / SAMPLE_RATE
|
|
|
- ) # time per output token: 0.02 (seconds)
|
|
|
- all_tokens = []
|
|
|
- all_segments = []
|
|
|
- prompt_reset_since = 0
|
|
|
-
|
|
|
- initial_prompt = decode_options.pop("initial_prompt", None)
|
|
|
- if initial_prompt and initial_prompt != "None":
|
|
|
- initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
|
|
|
- all_tokens.extend(initial_prompt)
|
|
|
- else:
|
|
|
- initial_prompt = []
|
|
|
-
|
|
|
- def add_segment(
|
|
|
- *,
|
|
|
- start: float,
|
|
|
- end: float,
|
|
|
- text_tokens: paddle.Tensor,
|
|
|
- result: DecodingResult,
|
|
|
+ # add newly finished sequences to self.finished_sequences
|
|
|
+ assert len(self.finished_sequences) == len(finished_sequences)
|
|
|
+ for previously_finished, newly_finished in zip(
|
|
|
+ self.finished_sequences, finished_sequences
|
|
|
):
|
|
|
- text = tokenizer.decode(
|
|
|
- [token for token in text_tokens if token < tokenizer.eot]
|
|
|
- )
|
|
|
- if len(text.strip()) == 0: # skip empty text output
|
|
|
- return
|
|
|
-
|
|
|
- all_segments.append(
|
|
|
- {
|
|
|
- "id": len(all_segments),
|
|
|
- "seek": seek,
|
|
|
- "start": start,
|
|
|
- "end": end,
|
|
|
- "text": text,
|
|
|
- "tokens": result.tokens,
|
|
|
- "temperature": result.temperature,
|
|
|
- "avg_logprob": result.avg_logprob,
|
|
|
- "compression_ratio": result.compression_ratio,
|
|
|
- "no_speech_prob": result.no_speech_prob,
|
|
|
- }
|
|
|
- )
|
|
|
- if verbose:
|
|
|
- print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
|
|
-
|
|
|
- # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
|
|
- num_frames = mel.shape[-1]
|
|
|
- previous_seek_value = seek
|
|
|
-
|
|
|
- with tqdm.tqdm(
|
|
|
- total=num_frames, unit="frames", disable=verbose is not False
|
|
|
- ) as pbar:
|
|
|
- while seek < num_frames:
|
|
|
- timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
|
|
- segment = pad_or_trim(mel[:, seek:], N_FRAMES)
|
|
|
- segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
|
|
-
|
|
|
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
|
- result: DecodingResult = decode_with_fallback(segment)
|
|
|
- tokens = paddle.to_tensor(result.tokens)
|
|
|
-
|
|
|
- if no_speech_threshold is not None:
|
|
|
- # no voice activity check
|
|
|
- should_skip = result.no_speech_prob > no_speech_threshold
|
|
|
- if (
|
|
|
- logprob_threshold is not None
|
|
|
- and result.avg_logprob > logprob_threshold
|
|
|
- ):
|
|
|
- # don't skip if the logprob is high enough, despite the no_speech_prob
|
|
|
- should_skip = False
|
|
|
-
|
|
|
- if should_skip:
|
|
|
- seek += segment.shape[
|
|
|
- -1
|
|
|
- ] # fast-forward to the next segment boundary
|
|
|
- continue
|
|
|
-
|
|
|
- timestamp_tokens: paddle.Tensor = tokens.greater_equal(
|
|
|
- paddle.to_tensor(tokenizer.timestamp_begin)
|
|
|
- )
|
|
|
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
|
|
+ if len(previously_finished) >= self.max_candidates:
|
|
|
+ break # the candidate list is full
|
|
|
+ previously_finished[seq] = newly_finished[seq]
|
|
|
+
|
|
|
+ # mark as completed once every audio input has enough finished candidates
|
|
|
+ completed = all(
|
|
|
+ len(sequences) >= self.max_candidates
|
|
|
+ for sequences in self.finished_sequences
|
|
|
+ )
|
|
|
+ return tokens, completed
|
|
|
|
|
|
- consecutive = paddle.where(
|
|
|
- timestamp_tokens[:-1] & timestamp_tokens[1:]
|
|
|
- )[0]
|
|
|
- if (
|
|
|
- len(consecutive) > 0
|
|
|
- ): # if the output contains two consecutive timestamp tokens
|
|
|
- consecutive = paddle.add(consecutive, paddle.to_tensor(1))
|
|
|
- last_slice = 0
|
|
|
- for current_slice in consecutive:
|
|
|
- sliced_tokens = tokens[last_slice:current_slice]
|
|
|
- start_timestamp_position = (
|
|
|
- sliced_tokens[0].item() - tokenizer.timestamp_begin
|
|
|
- )
|
|
|
- end_timestamp_position = (
|
|
|
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
|
|
- )
|
|
|
- add_segment(
|
|
|
- start=timestamp_offset
|
|
|
- + start_timestamp_position * time_precision,
|
|
|
- end=timestamp_offset
|
|
|
- + end_timestamp_position * time_precision,
|
|
|
- text_tokens=sliced_tokens[1:-1],
|
|
|
- result=result,
|
|
|
- )
|
|
|
- last_slice = current_slice
|
|
|
- last_timestamp_position = (
|
|
|
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
|
|
- )
|
|
|
- seek += last_timestamp_position * input_stride
|
|
|
- all_tokens.extend(tokens[: last_slice + 1].tolist())
|
|
|
- else:
|
|
|
- duration = segment_duration
|
|
|
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
|
- if (
|
|
|
- len(timestamps) > 0
|
|
|
- and timestamps[-1].item() != tokenizer.timestamp_begin
|
|
|
- ):
|
|
|
- # no consecutive timestamps but it has a timestamp; use the last one.
|
|
|
- # single timestamp at the end means no speech after the last timestamp.
|
|
|
- last_timestamp_position = (
|
|
|
- timestamps[-1].item() - tokenizer.timestamp_begin
|
|
|
- )
|
|
|
- duration = last_timestamp_position * time_precision
|
|
|
+ def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
|
|
|
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
|
|
|
+ sum_logprobs = sum_logprobs.cpu()
|
|
|
+ for i, sequences in enumerate(self.finished_sequences):
|
|
|
+ if (
|
|
|
+ len(sequences) < self.beam_size
|
|
|
+ ): # when not enough sequences are finished
|
|
|
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
|
|
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
|
|
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
|
|
+ if len(sequences) >= self.beam_size:
|
|
|
+ break
|
|
|
|
|
|
- add_segment(
|
|
|
- start=timestamp_offset,
|
|
|
- end=timestamp_offset + duration,
|
|
|
- text_tokens=tokens,
|
|
|
- result=result,
|
|
|
- )
|
|
|
+ tokens: List[List[paddle.Tensor]] = [
|
|
|
+ [paddle.to_tensor(seq) for seq in sequences.keys()]
|
|
|
+ for sequences in self.finished_sequences
|
|
|
+ ]
|
|
|
+ sum_logprobs: List[List[float]] = [
|
|
|
+ list(sequences.values()) for sequences in self.finished_sequences
|
|
|
+ ]
|
|
|
+ return tokens, sum_logprobs
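A toy restatement of the stopping rule implied by beam_size and patience above; the counts are made up.

# Each audio must collect round(beam_size * patience) finished candidates
# before beam search is allowed to stop.
beam_size, patience = 5, 2.0                   # hypothetical settings
max_candidates = round(beam_size * patience)   # 10

finished_per_audio = [10, 7]                   # fabricated progress for two inputs
completed = all(n >= max_candidates for n in finished_per_audio)
print(max_candidates, completed)               # 10 False: keep decoding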
|
|
|
|
|
|
- seek += segment.shape[-1]
|
|
|
- all_tokens.extend(tokens.tolist())
|
|
|
|
|
|
- if not condition_on_previous_text or result.temperature > 0.5:
|
|
|
- # do not feed the prompt tokens if a high temperature was used
|
|
|
- prompt_reset_since = len(all_tokens)
|
|
|
+class LogitFilter:
|
|
|
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
|
|
|
+ """Apply any filtering or masking to logits in-place
|
|
|
|
|
|
- # update progress bar
|
|
|
- pbar.update(min(num_frames, seek) - previous_seek_value)
|
|
|
- previous_seek_value = seek
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ logits : Tensor, shape = (n_batch, vocab_size)
|
|
|
+ per-token logits of the probability distribution at the current step
|
|
|
|
|
|
- return dict(
|
|
|
- text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
|
|
|
- segments=all_segments,
|
|
|
- language=language,
|
|
|
- )
|
|
|
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
|
+ all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
|
|
|
- class SequenceRanker:
|
|
|
- def rank(
|
|
|
- self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
|
|
|
- ) -> List[int]:
|
|
|
- """
|
|
|
- Given a list of groups of samples and their cumulative log probabilities,
|
|
|
- return the indices of the samples in each group to select as the final result
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- class MaximumLikelihoodRanker(SequenceRanker):
|
|
|
- """
|
|
|
- Select the sample with the highest log probabilities, penalized using either
|
|
|
- a simple length normalization or Google NMT paper's length penalty
|
|
|
"""
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
- def __init__(self, length_penalty: Optional[float]):
|
|
|
- self.length_penalty = length_penalty
|
|
|
|
|
|
- def rank(
|
|
|
- self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
|
|
|
- ):
|
|
|
- def scores(logprobs, lengths):
|
|
|
- result = []
|
|
|
- for logprob, length in zip(logprobs, lengths):
|
|
|
- if self.length_penalty is None or self.length_penalty == "None":
|
|
|
- penalty = length
|
|
|
- else:
|
|
|
- # from the Google NMT paper
|
|
|
- penalty = ((5 + length) / 6) ** self.length_penalty
|
|
|
- result.append(logprob / penalty)
|
|
|
- return result
|
|
|
-
|
|
|
- # get the sequence with the highest score
|
|
|
- lengths = [[len(t) for t in s] for s in tokens]
|
|
|
- return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
|
|
-
|
|
|
- class TokenDecoder:
|
|
|
- def reset(self):
|
|
|
- """Initialize any stateful variables for decoding a new sequence"""
|
|
|
-
|
|
|
- def update(
|
|
|
- self,
|
|
|
- tokens: paddle.Tensor,
|
|
|
- logits: paddle.Tensor,
|
|
|
- sum_logprobs: paddle.Tensor,
|
|
|
- ) -> Tuple[paddle.Tensor, bool]:
|
|
|
- """Specify how to select the next token, based on the current trace and logits
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
|
- all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
- logits : Tensor, shape = (n_batch, vocab_size)
|
|
|
- per-token logits of the probability distribution at the current step
|
|
|
- sum_logprobs : Tensor, shape = (n_batch)
|
|
|
- cumulative log probabilities for each sequence
|
|
|
- Returns
|
|
|
- -------
|
|
|
- tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
|
|
- the tokens, appended with the selected next token
|
|
|
- completed : bool
|
|
|
- True if all sequences has reached the end of text
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- def finalize(
|
|
|
- self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
|
|
|
- ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
|
|
|
- """Finalize search and return the final candidate sequences
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
|
|
|
- all tokens in the context so far, including the prefix and sot_sequence
|
|
|
- sum_logprobs : Tensor, shape = (batch_size, beam_size)
|
|
|
- cumulative log probabilities for each sequence
|
|
|
- Returns
|
|
|
- -------
|
|
|
- tokens : Sequence[Sequence[Tensor]], length = batch_size
|
|
|
- sequence of Tensors containing candidate token sequences, for each audio input
|
|
|
- sum_logprobs : List[List[float]], length = batch_size
|
|
|
- sequence of cumulative log probabilities corresponding to the above
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- class GreedyDecoder(TokenDecoder):
|
|
|
- def __init__(self, temperature: float, eot: int):
|
|
|
- self.temperature = temperature
|
|
|
- self.eot = eot
|
|
|
-
|
|
|
- def update(
|
|
|
- self,
|
|
|
- tokens: paddle.Tensor,
|
|
|
- logits: paddle.Tensor,
|
|
|
- sum_logprobs: paddle.Tensor,
|
|
|
- ) -> Tuple[paddle.Tensor, bool]:
|
|
|
- temperature = self.temperature
|
|
|
- if temperature == 0:
|
|
|
- next_tokens = paddle.argmax(logits, axis=-1)
|
|
|
- else:
|
|
|
- next_tokens = paddle.distribution.Categorical(
|
|
|
- logits=logits / temperature
|
|
|
- ).sample([1])
|
|
|
- next_tokens = paddle.reshape(
|
|
|
- next_tokens,
|
|
|
- [
|
|
|
- next_tokens.shape[0] * next_tokens.shape[1],
|
|
|
- ],
|
|
|
- )
|
|
|
+class SuppressBlank(LogitFilter):
|
|
|
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
|
|
+ self.tokenizer = tokenizer
|
|
|
+ self.sample_begin = sample_begin
|
|
|
|
|
|
- logprobs = paddle.nn.functional.log_softmax(
|
|
|
- logits, axis=-1, dtype=paddle.float32
|
|
|
- )
|
|
|
- current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
|
|
|
- sum_logprobs += current_logprobs * paddle.to_tensor(
|
|
|
- (tokens[:, -1] != self.eot), dtype=paddle.float32
|
|
|
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
+ if tokens.shape[1] == self.sample_begin:
|
|
|
+ logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
|
|
|
+ -np.inf
|
|
|
)
|
|
|
|
|
|
- next_tokens[tokens[:, -1] == self.eot] = self.eot
|
|
|
- tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
|
|
|
|
|
|
- completed = paddle.all((tokens[:, -1] == self.eot))
|
|
|
- return tokens, completed
|
|
|
+class SuppressTokens(LogitFilter):
|
|
|
+ def __init__(self, suppress_tokens: Sequence[int]):
|
|
|
+ self.suppress_tokens = list(suppress_tokens)
|
|
|
|
|
|
- def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
|
|
|
- # make sure each sequence has at least one EOT token at the end
|
|
|
- tokens = paddle.nn.functional.pad(
|
|
|
- tokens, (0, 1), value=self.eot, data_format="NCL"
|
|
|
- )
|
|
|
- return tokens, sum_logprobs.tolist()
|
|
|
-
|
|
|
- class BeamSearchDecoder(TokenDecoder):
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- beam_size: int,
|
|
|
- eot: int,
|
|
|
- inference: Inference,
|
|
|
- patience: Optional[float] = None,
|
|
|
- ):
|
|
|
- self.beam_size = beam_size
|
|
|
- self.eot = eot
|
|
|
- self.inference = inference
|
|
|
- self.patience = patience or 1.0
|
|
|
- if patience is None or patience == "None":
|
|
|
- self.patience = 1.0
|
|
|
- else:
|
|
|
- self.patience = patience
|
|
|
- self.max_candidates: int = round(beam_size * self.patience)
|
|
|
- self.finished_sequences = None
|
|
|
-
|
|
|
- assert (
|
|
|
- self.max_candidates > 0
|
|
|
- ), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
|
|
-
|
|
|
- def reset(self):
|
|
|
- self.finished_sequences = None
|
|
|
-
|
|
|
- def update(
|
|
|
- self,
|
|
|
- tokens: paddle.Tensor,
|
|
|
- logits: paddle.Tensor,
|
|
|
- sum_logprobs: paddle.Tensor,
|
|
|
- ) -> Tuple[paddle.Tensor, bool]:
|
|
|
- if tokens.shape[0] % self.beam_size != 0:
|
|
|
- raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
|
|
-
|
|
|
- batch_size = tokens.shape[0] // self.beam_size
|
|
|
- if self.finished_sequences is None: # for the first update
|
|
|
- self.finished_sequences = [{} for _ in range(batch_size)]
|
|
|
-
|
|
|
- logprobs = paddle.nn.functional.log_softmax(
|
|
|
- logits, axis=-1, dtype="float32"
|
|
|
- )
|
|
|
- next_tokens, source_indices, finished_sequences = [], [], []
|
|
|
- for i in range(batch_size):
|
|
|
- scores, sources, finished = {}, {}, {}
|
|
|
-
|
|
|
- # STEP 1: calculate the cumulative log probabilities for possible candidates
|
|
|
- for j in range(self.beam_size):
|
|
|
- idx = i * self.beam_size + j
|
|
|
- prefix = tokens[idx].tolist()
|
|
|
- logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
|
|
|
- for logprob, token in zip(logprob, token):
|
|
|
- new_logprob = (sum_logprobs[idx] + logprob).item()
|
|
|
- sequence = tuple(prefix + [token.item()])
|
|
|
- scores[sequence] = new_logprob
|
|
|
- sources[sequence] = idx
|
|
|
-
|
|
|
- # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
|
|
- saved = 0
|
|
|
- for sequence in sorted(scores, key=scores.get, reverse=True):
|
|
|
- if sequence[-1] == self.eot:
|
|
|
- finished[sequence] = scores[sequence]
|
|
|
- else:
|
|
|
- sum_logprobs[len(next_tokens)] = scores[sequence]
|
|
|
- next_tokens.append(sequence)
|
|
|
- source_indices.append(sources[sequence])
|
|
|
-
|
|
|
- saved += 1
|
|
|
- if saved == self.beam_size:
|
|
|
- break
|
|
|
-
|
|
|
- finished_sequences.append(finished)
|
|
|
-
|
|
|
- tokens = paddle.to_tensor(next_tokens)
|
|
|
- self.inference.rearrange_kv_cache(source_indices)
|
|
|
-
|
|
|
- # add newly finished sequences to self.finished_sequences
|
|
|
- assert len(self.finished_sequences) == len(finished_sequences)
|
|
|
- for previously_finished, newly_finished in zip(
|
|
|
- self.finished_sequences, finished_sequences
|
|
|
- ):
|
|
|
- for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
|
|
- if len(previously_finished) >= self.max_candidates:
|
|
|
- break # the candidate list is full
|
|
|
- previously_finished[seq] = newly_finished[seq]
|
|
|
-
|
|
|
- # mark as completed if all audio has enough number of samples
|
|
|
- completed = all(
|
|
|
- len(sequences) >= self.max_candidates
|
|
|
- for sequences in self.finished_sequences
|
|
|
- )
|
|
|
- return tokens, completed
|
|
|
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
+ logits[:, self.suppress_tokens] = -np.inf
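Both filters above rely on the same mechanism: writing -inf into banned logit columns makes their probability exactly zero after softmax. A small numpy illustration with a toy vocabulary:

import numpy as np

logits = np.array([[1.0, 2.0, 0.5, 0.1]])   # toy (n_batch, vocab_size)
banned = [1, 3]                              # hypothetical suppressed ids
logits[:, banned] = -np.inf

probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
print(probs)                                 # banned columns come out as 0.0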
|
|
|
|
|
|
- def finalize(
|
|
|
- self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
|
|
|
- ):
|
|
|
- # collect all finished sequences, including patience, and add unfinished ones if not enough
|
|
|
- sum_logprobs = sum_logprobs.cpu()
|
|
|
- for i, sequences in enumerate(self.finished_sequences):
|
|
|
- if (
|
|
|
- len(sequences) < self.beam_size
|
|
|
- ): # when not enough sequences are finished
|
|
|
- for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
|
|
- sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
|
|
- sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
|
|
- if len(sequences) >= self.beam_size:
|
|
|
- break
|
|
|
-
|
|
|
- tokens: List[List[paddle.Tensor]] = [
|
|
|
- [paddle.to_tensor(seq) for seq in sequences.keys()]
|
|
|
- for sequences in self.finished_sequences
|
|
|
- ]
|
|
|
- sum_logprobs: List[List[float]] = [
|
|
|
- list(sequences.values()) for sequences in self.finished_sequences
|
|
|
- ]
|
|
|
- return tokens, sum_logprobs
|
|
|
-
|
|
|
- class LogitFilter:
|
|
|
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
|
|
|
- """Apply any filtering or masking to logits in-place
|
|
|
-
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- logits : Tensor, shape = (n_batch, vocab_size)
|
|
|
- per-token logits of the probability distribution at the current step
|
|
|
-
|
|
|
- tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
|
- all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
-
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- class SuppressBlank(LogitFilter):
|
|
|
- def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
|
|
- self.tokenizer = tokenizer
|
|
|
- self.sample_begin = sample_begin
|
|
|
-
|
|
|
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
- if tokens.shape[1] == self.sample_begin:
|
|
|
- logits[
|
|
|
- :, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]
|
|
|
- ] = -np.inf
|
|
|
-
|
|
|
- class SuppressTokens(LogitFilter):
|
|
|
- def __init__(self, suppress_tokens: Sequence[int]):
|
|
|
- self.suppress_tokens = list(suppress_tokens)
|
|
|
-
|
|
|
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
- logits[:, self.suppress_tokens] = -np.inf
|
|
|
-
|
|
|
- class ApplyTimestampRules(LogitFilter):
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- tokenizer: Tokenizer,
|
|
|
- sample_begin: int,
|
|
|
- max_initial_timestamp_index: Optional[int],
|
|
|
- ):
|
|
|
- self.tokenizer = tokenizer
|
|
|
- self.sample_begin = sample_begin
|
|
|
- self.max_initial_timestamp_index = max_initial_timestamp_index
|
|
|
-
|
|
|
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
- # suppress <|notimestamps|> which is handled by without_timestamps
|
|
|
- if self.tokenizer.no_timestamps is not None:
|
|
|
- logits[:, self.tokenizer.no_timestamps] = -np.inf
|
|
|
-
|
|
|
- # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
|
|
- for k in range(tokens.shape[0]):
|
|
|
- seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
|
|
- last_was_timestamp = (
|
|
|
- len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
|
|
- )
|
|
|
- penultimate_was_timestamp = (
|
|
|
- len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
|
|
- )
|
|
|
-
|
|
|
- if last_was_timestamp:
|
|
|
- if penultimate_was_timestamp: # has to be non-timestamp
|
|
|
- logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
|
|
- else: # cannot be normal text tokens
|
|
|
- logits[k, : self.tokenizer.eot] = -np.inf
|
|
|
|
|
|
- # apply the `max_initial_timestamp` option
|
|
|
- if (
|
|
|
- tokens.shape[1] == self.sample_begin
|
|
|
- and self.max_initial_timestamp_index is not None
|
|
|
- ):
|
|
|
- last_allowed = (
|
|
|
- self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
|
|
- )
|
|
|
- logits[:, last_allowed + 1 :] = -np.inf
|
|
|
-
|
|
|
- # if sum of probability over timestamps is above any other token, sample timestamp
|
|
|
- logprobs = paddle.nn.functional.log_softmax(
|
|
|
- logits, axis=-1, dtype="float32"
|
|
|
+class ApplyTimestampRules(LogitFilter):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ tokenizer: Tokenizer,
|
|
|
+ sample_begin: int,
|
|
|
+ max_initial_timestamp_index: Optional[int],
|
|
|
+ ):
|
|
|
+ self.tokenizer = tokenizer
|
|
|
+ self.sample_begin = sample_begin
|
|
|
+ self.max_initial_timestamp_index = max_initial_timestamp_index
|
|
|
+
|
|
|
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
|
+ # suppress <|notimestamps|> which is handled by without_timestamps
|
|
|
+ if self.tokenizer.no_timestamps is not None:
|
|
|
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
|
|
|
+
|
|
|
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
|
|
+ for k in range(tokens.shape[0]):
|
|
|
+ seq = tokens[k, self.sample_begin :].tolist()
|
|
|
+ last_was_timestamp = (
|
|
|
+ len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
|
|
+ )
|
|
|
+ penultimate_was_timestamp = (
|
|
|
+ len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
|
|
)
|
|
|
- for k in range(tokens.shape[0]):
|
|
|
- # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
|
|
|
- # To bypass this issue in CI, we have decomposed the operation into separate steps.
|
|
|
- # It will raise 2e-6 difference in precision.
|
|
|
- # TODO: revert this after logsumexp been fixed.
|
|
|
- timestamp_logprob = paddle.exp(
|
|
|
- logprobs[k, self.tokenizer.timestamp_begin :]
|
|
|
- )
|
|
|
- timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
|
|
|
- timestamp_logprob = paddle.log(timestamp_logprob)
|
|
|
- max_text_token_logprob = paddle.max(
|
|
|
- logprobs[k, : self.tokenizer.timestamp_begin]
|
|
|
- )
|
|
|
- if timestamp_logprob > max_text_token_logprob:
|
|
|
- logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
|
|
|
|
|
- class DecodingTask:
|
|
|
- inference: Inference
|
|
|
- sequence_ranker: SequenceRanker
|
|
|
- decoder: TokenDecoder
|
|
|
- logit_filters: List[LogitFilter]
|
|
|
+ if last_was_timestamp:
|
|
|
+ if penultimate_was_timestamp: # has to be non-timestamp
|
|
|
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
|
|
+ else: # cannot be normal text tokens
|
|
|
+ logits[k, : self.tokenizer.eot] = -np.inf
|
|
|
|
|
|
- def __init__(
|
|
|
- self, model: "Whisper", options: DecodingOptions, resource_path: str
|
|
|
+ # apply the `max_initial_timestamp` option
|
|
|
+ if (
|
|
|
+ tokens.shape[1] == self.sample_begin
|
|
|
+ and self.max_initial_timestamp_index is not None
|
|
|
):
|
|
|
- self.model = model
|
|
|
-
|
|
|
- language = options.language or "en"
|
|
|
- tokenizer = get_tokenizer(
|
|
|
- model.is_multilingual,
|
|
|
- resource_path=resource_path,
|
|
|
- language=language,
|
|
|
- task=options.task,
|
|
|
+ last_allowed = (
|
|
|
+ self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
|
|
)
|
|
|
- self.tokenizer: Tokenizer = tokenizer
|
|
|
- self.options: DecodingOptions = self._verify_options(options)
|
|
|
- self.resource_path: str = resource_path
|
|
|
-
|
|
|
- self.beam_size: int = options.beam_size or options.best_of or 1
|
|
|
- self.n_ctx: int = model.dims.n_text_ctx
|
|
|
- self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
|
|
+ logits[:, last_allowed + 1 :] = -np.inf
|
|
|
+
|
|
|
+ # if sum of probability over timestamps is above any other token, sample timestamp
|
|
|
+ logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
|
|
|
+ for k in range(tokens.shape[0]):
|
|
|
+ # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
|
|
|
+ # To bypass this issue in CI, we have decomposed the operation into separate steps.
|
|
|
+ # This introduces up to a 2e-6 difference in precision.
|
|
|
+ # TODO: revert this once logsumexp has been fixed.
|
|
|
+ timestamp_logprob = paddle.exp(
|
|
|
+ logprobs[k, self.tokenizer.timestamp_begin :]
|
|
|
+ )
|
|
|
+ timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
|
|
|
+ timestamp_logprob = paddle.log(timestamp_logprob)
|
|
|
+ max_text_token_logprob = paddle.max(
|
|
|
+ logprobs[k, : self.tokenizer.timestamp_begin]
|
|
|
+ )
|
|
|
+ if timestamp_logprob > max_text_token_logprob:
|
|
|
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
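The exp/sum/log decomposition above (see the CUDA error 700 comment) is just log-sum-exp without the usual max-subtraction stabilization; a quick numpy check with arbitrary values shows the two agree to float32 precision.

import numpy as np

x = np.array([0.3, -1.2, 2.0], dtype=np.float32)   # arbitrary logprob slice
decomposed = np.log(np.exp(x).sum())               # the workaround used above
reference = np.logaddexp.reduce(x)                 # stable logsumexp
print(abs(decomposed - reference))                 # on the order of 1e-6 or less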
|
|
|
|
|
|
- self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
|
|
- if self.options.without_timestamps:
|
|
|
- self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
|
|
|
|
|
- self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
|
|
- self.sample_begin: int = len(self.initial_tokens)
|
|
|
- self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
|
|
+class DecodingTask:
|
|
|
+ inference: Inference
|
|
|
+ sequence_ranker: SequenceRanker
|
|
|
+ decoder: TokenDecoder
|
|
|
+ logit_filters: List[LogitFilter]
|
|
|
|
|
|
- # inference: implements the forward pass through the decoder, including kv caching
|
|
|
- self.inference = WhisperInference(model, len(self.initial_tokens))
|
|
|
+ def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
|
|
|
+ self.model = model
|
|
|
|
|
|
- # sequence ranker: implements how to rank a group of sampled sequences
|
|
|
- self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
|
|
+ language = options.language or "en"
|
|
|
+ tokenizer = get_tokenizer(
|
|
|
+ model.is_multilingual,
|
|
|
+ resource_path=resource_path,
|
|
|
+ language=language,
|
|
|
+ task=options.task,
|
|
|
+ )
|
|
|
+ self.tokenizer: Tokenizer = tokenizer
|
|
|
+ self.options: DecodingOptions = self._verify_options(options)
|
|
|
+ self.resource_path: str = resource_path
|
|
|
|
|
|
- # decoder: implements how to select the next tokens, given the autoregressive distribution
|
|
|
- if options.beam_size is not None:
|
|
|
- self.decoder = BeamSearchDecoder(
|
|
|
- options.beam_size, tokenizer.eot, self.inference, options.patience
|
|
|
- )
|
|
|
- else:
|
|
|
- self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
|
|
+ self.beam_size: int = options.beam_size or options.best_of or 1
|
|
|
+ self.n_ctx: int = model.dims.n_text_ctx
|
|
|
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
|
|
|
|
|
- # logit filters: applies various rules to suppress or penalize certain tokens
|
|
|
- self.logit_filters = []
|
|
|
- if self.options.suppress_blank:
|
|
|
- self.logit_filters.append(
|
|
|
- SuppressBlank(self.tokenizer, self.sample_begin)
|
|
|
- )
|
|
|
- if self.options.suppress_tokens:
|
|
|
- self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
|
|
- if not options.without_timestamps:
|
|
|
- precision = (
|
|
|
- CHUNK_LENGTH / model.dims.n_audio_ctx
|
|
|
- ) # usually 0.02 seconds
|
|
|
- max_initial_timestamp_index = None
|
|
|
- if options.max_initial_timestamp:
|
|
|
- max_initial_timestamp_index = round(
|
|
|
- self.options.max_initial_timestamp / precision
|
|
|
- )
|
|
|
- self.logit_filters.append(
|
|
|
- ApplyTimestampRules(
|
|
|
- tokenizer, self.sample_begin, max_initial_timestamp_index
|
|
|
- )
|
|
|
- )
|
|
|
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
|
|
+ if self.options.without_timestamps:
|
|
|
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
|
|
|
|
|
- def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
|
|
- if options.beam_size is not None and options.best_of is not None:
|
|
|
- raise ValueError("beam_size and best_of can't be given together")
|
|
|
- if options.temperature == 0:
|
|
|
- if options.best_of is not None:
|
|
|
- raise ValueError(
|
|
|
- "best_of with greedy sampling (T=0) is not compatible"
|
|
|
- )
|
|
|
- if options.patience is not None and options.beam_size is None:
|
|
|
- raise ValueError("patience requires beam_size to be given")
|
|
|
- if options.length_penalty is not None and options.length_penalty != "None":
|
|
|
- if not (0 <= options.length_penalty <= 1):
|
|
|
- raise ValueError(
|
|
|
- "length_penalty (alpha) should be a value between 0 and 1"
|
|
|
- )
|
|
|
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
|
|
+ self.sample_begin: int = len(self.initial_tokens)
|
|
|
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
|
|
|
|
|
- return options
|
|
|
+ # inference: implements the forward pass through the decoder, including kv caching
|
|
|
+ self.inference = WhisperInference(model, len(self.initial_tokens))
|
|
|
|
|
|
- def _get_initial_tokens(self) -> Tuple[int]:
|
|
|
- tokens = list(self.sot_sequence)
|
|
|
- prefix = self.options.prefix
|
|
|
- prompt = self.options.prompt
|
|
|
+ # sequence ranker: implements how to rank a group of sampled sequences
|
|
|
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
|
|
|
|
|
- if prefix:
|
|
|
- prefix_tokens = (
|
|
|
- self.tokenizer.encode(" " + prefix.strip().input_ids)
|
|
|
- if isinstance(prefix, str)
|
|
|
- else prefix
|
|
|
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
|
|
|
+ if options.beam_size is not None:
|
|
|
+ self.decoder = BeamSearchDecoder(
|
|
|
+ options.beam_size, tokenizer.eot, self.inference, options.patience
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
|
|
+
|
|
|
+ # logit filters: applies various rules to suppress or penalize certain tokens
|
|
|
+ self.logit_filters = []
|
|
|
+ if self.options.suppress_blank:
|
|
|
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
|
|
+ if self.options.suppress_tokens:
|
|
|
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
|
|
+ if not options.without_timestamps:
|
|
|
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
|
|
+ max_initial_timestamp_index = None
|
|
|
+ if options.max_initial_timestamp:
|
|
|
+ max_initial_timestamp_index = round(
|
|
|
+ self.options.max_initial_timestamp / precision
|
|
|
)
|
|
|
- if self.sample_len is not None:
|
|
|
- max_prefix_len = self.n_ctx // 2 - self.sample_len
|
|
|
- prefix_tokens = prefix_tokens[-max_prefix_len:]
|
|
|
- tokens = tokens + prefix_tokens
|
|
|
-
|
|
|
- if prompt:
|
|
|
- prompt_tokens = (
|
|
|
- self.tokenizer.encode(" " + prompt.strip().input_ids)
|
|
|
- if isinstance(prompt, str)
|
|
|
- else prompt
|
|
|
+ self.logit_filters.append(
|
|
|
+ ApplyTimestampRules(
|
|
|
+ tokenizer, self.sample_begin, max_initial_timestamp_index
|
|
|
)
|
|
|
- tokens = (
|
|
|
- [self.tokenizer.sot_prev]
|
|
|
- + prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
|
|
- + tokens
|
|
|
+ )
|
|
|
+
|
|
|
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
|
|
+ if options.beam_size is not None and options.best_of is not None:
|
|
|
+ raise ValueError("beam_size and best_of can't be given together")
|
|
|
+ if options.temperature == 0:
|
|
|
+ if options.best_of is not None:
|
|
|
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
|
|
+ if options.patience is not None and options.beam_size is None:
|
|
|
+ raise ValueError("patience requires beam_size to be given")
|
|
|
+ if options.length_penalty is not None and options.length_penalty != "None":
|
|
|
+ if not (0 <= options.length_penalty <= 1):
|
|
|
+ raise ValueError(
|
|
|
+ "length_penalty (alpha) should be a value between 0 and 1"
|
|
|
)
|
|
|
|
|
|
- return tuple(tokens)
|
|
|
+ return options
|
|
|
|
|
|
- def _get_suppress_tokens(self) -> Tuple[int]:
|
|
|
- suppress_tokens = self.options.suppress_tokens
|
|
|
+ def _get_initial_tokens(self) -> Tuple[int]:
|
|
|
+ tokens = list(self.sot_sequence)
|
|
|
+ prefix = self.options.prefix
|
|
|
+ prompt = self.options.prompt
|
|
|
|
|
|
- if isinstance(suppress_tokens, str):
|
|
|
- suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
|
|
+ if prefix:
|
|
|
+ prefix_tokens = (
|
|
|
+ self.tokenizer.encode(" " + prefix.strip().input_ids)
|
|
|
+ if isinstance(prefix, str)
|
|
|
+ else prefix
|
|
|
+ )
|
|
|
+ if self.sample_len is not None:
|
|
|
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
|
|
|
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
|
|
|
+ tokens = tokens + prefix_tokens
|
|
|
+
|
|
|
+ if prompt:
|
|
|
+ prompt_tokens = (
|
|
|
+ self.tokenizer.encode(" " + prompt.strip().input_ids)
|
|
|
+ if isinstance(prompt, str)
|
|
|
+ else prompt
|
|
|
+ )
|
|
|
+ tokens = (
|
|
|
+ [self.tokenizer.sot_prev]
|
|
|
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
|
|
+ + tokens
|
|
|
+ )
|
|
|
|
|
|
- if -1 in suppress_tokens:
|
|
|
- suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
|
|
- suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
|
|
- elif suppress_tokens is None or len(suppress_tokens) == 0:
|
|
|
- suppress_tokens = [] # interpret empty string as an empty list
|
|
|
- else:
|
|
|
- assert isinstance(
|
|
|
- suppress_tokens, list
|
|
|
- ), "suppress_tokens must be a list"
|
|
|
+ return tuple(tokens)
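A toy trace of the token layout produced above: the prompt window (prefixed by sot_prev) comes first, then the sot sequence, then any prefix tokens. All ids below are fabricated.

sot_prev = 900                      # fabricated special-token id
sot_sequence = [901, 902, 903]      # fabricated sot/language/task ids
prompt_tokens = [11, 12, 13]        # fabricated previous-window tokens
prefix_tokens = [21, 22]            # fabricated prefix tokens
n_ctx = 448

tokens = list(sot_sequence) + prefix_tokens
tokens = [sot_prev] + prompt_tokens[-(n_ctx // 2 - 1):] + tokens
print(tokens)   # [900, 11, 12, 13, 901, 902, 903, 21, 22]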
|
|
|
|
|
|
- suppress_tokens.extend(
|
|
|
- [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
|
|
- )
|
|
|
- if self.tokenizer.no_speech is not None:
|
|
|
- # no-speech probability is collected separately
|
|
|
- suppress_tokens.append(self.tokenizer.no_speech)
|
|
|
+ def _get_suppress_tokens(self) -> Tuple[int]:
|
|
|
+ suppress_tokens = self.options.suppress_tokens
|
|
|
|
|
|
- return tuple(sorted(set(suppress_tokens)))
|
|
|
+ if isinstance(suppress_tokens, str):
|
|
|
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
|
|
|
|
|
- def _get_audio_features(self, mel: paddle.Tensor):
|
|
|
+ if -1 in suppress_tokens:
|
|
|
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
|
|
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
|
|
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
|
|
|
+ suppress_tokens = [] # interpret empty string as an empty list
|
|
|
+ else:
|
|
|
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
|
|
|
|
|
- if mel.shape[-2:] == (
|
|
|
- self.model.dims.n_audio_ctx,
|
|
|
- self.model.dims.n_audio_state,
|
|
|
- ):
|
|
|
- # encoded audio features are given; skip audio encoding
|
|
|
- audio_features = mel
|
|
|
- else:
|
|
|
- audio_features = self.model.encoder(mel)
|
|
|
+ suppress_tokens.extend(
|
|
|
+ [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
|
|
+ )
|
|
|
+ if self.tokenizer.no_speech is not None:
|
|
|
+ # no-speech probability is collected separately
|
|
|
+ suppress_tokens.append(self.tokenizer.no_speech)
|
|
|
|
|
|
- return audio_features
|
|
|
+ return tuple(sorted(set(suppress_tokens)))
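The normalization above accepts an int list, a comma-separated string, or -1 as shorthand for the non-speech set; a self-contained rerun of that logic with fabricated non-speech ids:

def normalize_suppress(suppress_tokens, non_speech_tokens=(7, 8, 9)):
    # mirrors _get_suppress_tokens above, minus the tokenizer specials
    if isinstance(suppress_tokens, str):
        suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
    if -1 in suppress_tokens:
        suppress_tokens = [t for t in suppress_tokens if t >= 0]
        suppress_tokens.extend(non_speech_tokens)
    return tuple(sorted(set(suppress_tokens)))

print(normalize_suppress("-1,99"))   # (7, 8, 9, 99)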
|
|
|
|
|
|
- def _detect_language(
|
|
|
- self,
|
|
|
- audio_features: paddle.Tensor,
|
|
|
- tokens: paddle.Tensor,
|
|
|
- resource_path: str,
|
|
|
+ def _get_audio_features(self, mel: paddle.Tensor):
|
|
|
+
|
|
|
+ if mel.shape[-2:] == (
|
|
|
+ self.model.dims.n_audio_ctx,
|
|
|
+ self.model.dims.n_audio_state,
|
|
|
):
|
|
|
- languages = [self.options.language] * audio_features.shape[0]
|
|
|
- lang_probs = None
|
|
|
+ # encoded audio features are given; skip audio encoding
|
|
|
+ audio_features = mel
|
|
|
+ else:
|
|
|
+ audio_features = self.model.encoder(mel)
|
|
|
|
|
|
- if self.options.language is None or self.options.task == "lang_id":
|
|
|
- lang_tokens, lang_probs = self.model.detect_language(
|
|
|
- audio_features, self.tokenizer, self.resource_path
|
|
|
- )
|
|
|
- languages = [max(probs, key=probs.get) for probs in lang_probs]
|
|
|
- if self.options.language is None:
|
|
|
- tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
|
|
+ return audio_features
|
|
|
|
|
|
- return languages, lang_probs
|
|
|
+ def _detect_language(
|
|
|
+ self,
|
|
|
+ audio_features: paddle.Tensor,
|
|
|
+ tokens: paddle.Tensor,
|
|
|
+ resource_path: str,
|
|
|
+ ):
|
|
|
+ languages = [self.options.language] * audio_features.shape[0]
|
|
|
+ lang_probs = None
|
|
|
|
|
|
-        def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
-            assert audio_features.shape[0] == tokens.shape[0]
-            n_batch = tokens.shape[0]
-            sum_logprobs: paddle.Tensor = paddle.zeros(
-                paddle.to_tensor(n_batch), dtype=paddle.float32
-            )
-            no_speech_probs = [np.nan] * n_batch
-
-            try:
-                for i in range(self.sample_len):
-                    logits = self.inference.logits(tokens, audio_features)
-
-                    if (
-                        i == 0 and self.tokenizer.no_speech is not None
-                    ):  # save no_speech_probs
-                        probs_at_sot = paddle.nn.functional.softmax(
-                            logits[:, self.sot_index], axis=-1, dtype=paddle.float32
-                        )
-                        no_speech_probs = probs_at_sot[
-                            :, self.tokenizer.no_speech
-                        ].tolist()
-
-                    # now we need to consider the logits at the last token only
-                    logits = logits[:, -1]
-
-                    # apply the logit filters, e.g. for suppressing or applying penalty to
-                    for logit_filter in self.logit_filters:
-                        logit_filter.apply(logits, tokens)
-
-                    # expand the tokens tensor with the selected next tokens
-                    tokens, completed = self.decoder.update(
-                        tokens, logits, sum_logprobs
-                    )
-                    if completed or tokens.shape[-1] > self.n_ctx:
-                        break
-            finally:
-                self.inference.cleanup_caching()
-
-            return tokens, sum_logprobs, no_speech_probs
-
-        @paddle.no_grad()
-        def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
-            self.decoder.reset()
-            tokenizer: Tokenizer = self.tokenizer
-            batch_size: int = mel.shape[0]
-
-            audio_features: paddle.Tensor = self._get_audio_features(
-                mel
-            )  # encoder forward pass
-
-            tokens: paddle.Tensor
-            if batch_size > 1:
-                for i in range(batch_size):
-                    tokens = paddle.concat(
-                        x=[
-                            paddle.to_tensor([self.initial_tokens]),
-                            paddle.to_tensor([self.initial_tokens]),
-                        ],
-                        axis=0,
-                    )
-            elif batch_size == 1:
-                tokens = paddle.to_tensor([self.initial_tokens])
-
-            # detect language if requested, overwriting the language token
-            languages, language_probs = self._detect_language(
-                paddle.to_tensor(audio_features),
-                paddle.to_tensor(tokens),
-                self.resource_path,
+        if self.options.language is None or self.options.task == "lang_id":
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer, self.resource_path
            )
+            languages = [max(probs, key=probs.get) for probs in lang_probs]
+            if self.options.language is None:
+                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

-            if self.options.task == "lang_id":
-                return [
-                    DecodingResult(
-                        audio_features=features, language=language, language_probs=probs
-                    )
-                    for features, language, probs in zip(
-                        audio_features, languages, language_probs
+        return languages, lang_probs
+
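To make the selection rule above concrete: `model.detect_language` yields one probability map per segment, and `max(probs, key=probs.get)` picks the most likely language code. A toy illustration with made-up numbers:

    # Hypothetical per-segment probabilities over the LANGUAGES code space.
    lang_probs = [{"en": 0.91, "zh": 0.06, "de": 0.03}]
    languages = [max(probs, key=probs.get) for probs in lang_probs]
    print(languages)  # ['en']
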
+    def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
+        assert audio_features.shape[0] == tokens.shape[0]
+        n_batch = tokens.shape[0]
+        sum_logprobs: paddle.Tensor = paddle.zeros(
+            paddle.to_tensor(n_batch), dtype=paddle.float32
+        )
+        no_speech_probs = [np.nan] * n_batch
+
+        try:
+            for i in range(self.sample_len):
+                logits = self.inference.logits(tokens, audio_features)
+
+                if (
+                    i == 0 and self.tokenizer.no_speech is not None
+                ):  # save no_speech_probs
+                    probs_at_sot = paddle.nn.functional.softmax(
+                        logits[:, self.sot_index], axis=-1, dtype=paddle.float32
                    )
-                ]
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

-            # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
-            audio_features = paddle.repeat_interleave(
-                audio_features, self.beam_size, axis=0
-            )
-            tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
-            # call the main sampling loop
-            tokens, sum_logprobs, no_speech_probs = self._main_loop(
-                audio_features, tokens
-            )
-            # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
-            audio_features = audio_features[:: self.beam_size]
-            no_speech_probs = no_speech_probs[:: self.beam_size]
-            assert audio_features.shape[0] == len(no_speech_probs) == batch_size
-            tokens = tokens.reshape([batch_size, self.beam_size, -1])
-            sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
-
-            # get the final candidates for each group, and slice between the first sampled token and EOT
-            tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-            tokens: List[List[paddle.Tensor]] = [
-                [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
-                for s in tokens
-            ]
+                # now we need to consider the logits at the last token only
+                logits = logits[:, -1]

-            # select the top-ranked sample in each group
-            selected = self.sequence_ranker.rank(tokens, sum_logprobs)
-            tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
-            texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+                # apply the logit filters, e.g. for suppressing forbidden tokens or applying a repetition penalty
+                for logit_filter in self.logit_filters:
+                    logit_filter.apply(logits, tokens)

-            sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-            avg_logprobs: List[float] = [
-                lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
-            ]
+                # expand the tokens tensor with the selected next tokens
+                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+                if completed or tokens.shape[-1] > self.n_ctx:
+                    break
+        finally:
+            self.inference.cleanup_caching()

-            fields = (
-                texts,
-                languages,
-                tokens,
-                audio_features,
-                avg_logprobs,
-                no_speech_probs,
-            )
-            if len(set(map(len, fields))) != 1:
-                raise RuntimeError(
-                    f"inconsistent result lengths: {list(map(len, fields))}"
+        return tokens, sum_logprobs, no_speech_probs
+
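The `self.decoder.update` call inside the loop is where the sampling strategy lives. A minimal greedy stand-in at temperature 0 (the module's actual decoder classes also track already-finished sequences, which is omitted here):

    import paddle

    def greedy_update(tokens, logits, sum_logprobs, eot: int):
        # Pick the argmax token per sequence and accumulate its log-probability.
        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1)
        next_tokens = logits.argmax(axis=-1, keepdim=True)
        sum_logprobs += paddle.take_along_axis(logprobs, next_tokens, axis=-1).squeeze(-1)
        tokens = paddle.concat([tokens, next_tokens], axis=-1)
        # Decoding is complete once every sequence has emitted end-of-text.
        completed = bool((next_tokens.squeeze(-1) == eot).all())
        return tokens, completed
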
+    @paddle.no_grad()
+    def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
+        self.decoder.reset()
+        tokenizer: Tokenizer = self.tokenizer
+        batch_size: int = mel.shape[0]
+
+        audio_features: paddle.Tensor = self._get_audio_features(
+            mel
+        )  # encoder forward pass
+
+        tokens: paddle.Tensor
+        if batch_size > 1:
+            # replicate the initial token sequence once per item in the batch
+            tokens = paddle.concat(
+                x=[paddle.to_tensor([self.initial_tokens])] * batch_size,
+                axis=0,
            )
+        elif batch_size == 1:
+            tokens = paddle.to_tensor([self.initial_tokens])
+
+        # detect language if requested, overwriting the language token
+        languages, language_probs = self._detect_language(
+            paddle.to_tensor(audio_features),
+            paddle.to_tensor(tokens),
+            self.resource_path,
+        )

+        if self.options.task == "lang_id":
            return [
                DecodingResult(
-                    audio_features=features,
-                    language=language,
-                    tokens=tokens,
-                    text=text,
-                    avg_logprob=avg_logprob,
-                    no_speech_prob=no_speech_prob,
-                    temperature=self.options.temperature,
-                    compression_ratio=compression_ratio(text),
+                    audio_features=features, language=language, language_probs=probs
                )
-                for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
-                    *fields
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
                )
            ]

-    @paddle.no_grad()
-    def decode(
-        model: "Whisper",
-        mel: paddle.Tensor,
-        options: DecodingOptions = DecodingOptions(),
-        resource_path=str,
-    ) -> Union[DecodingResult, List[DecodingResult]]:
-        """
-        Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
-        Parameters
-        ----------
-        model: Whisper
-            the Whisper model instance
-        mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
-            A tensor containing the Mel spectrogram(s)
-        options: DecodingOptions
-            A dataclass that contains all necessary options for decoding 30-second segments
-        Returns
-        -------
-        result: Union[DecodingResult, List[DecodingResult]]
-            The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
-        """
-        single = mel.ndim == 2
-        if single:
-            mel = mel.unsqueeze(0)
-
-        result = DecodingTask(model, options, resource_path).run(mel)
+        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = paddle.repeat_interleave(
+            audio_features, self.beam_size, axis=0
+        )
+        tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
+        # call the main sampling loop
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+        # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
+        audio_features = audio_features[:: self.beam_size]
+        no_speech_probs = no_speech_probs[:: self.beam_size]
+        assert audio_features.shape[0] == len(no_speech_probs) == batch_size
+        tokens = tokens.reshape([batch_size, self.beam_size, -1])
+        sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
+
+        # get the final candidates for each group, and slice between the first sampled token and EOT
+        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens: List[List[paddle.Tensor]] = [
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
+        ]

-        if single:
-            result = result[0]
+        # select the top-ranked sample in each group
+        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

-        return result
+        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]

-    class Whisper(paddle.nn.Layer):
-        """
-        The `Whisper` module use AudioEncoder and TextDecoder, and return detect_language, transcribe, decode.
-        """
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
+        if len(set(map(len, fields))) != 1:
+            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

-        def __init__(self, dims: ModelDimensions):
-            super().__init__()
-            self.dims = dims
-            self.encoder = AudioEncoder(
-                self.dims.n_mels,
-                self.dims.n_audio_ctx,
-                self.dims.n_audio_state,
-                self.dims.n_audio_head,
-                self.dims.n_audio_layer,
+        return [
+            DecodingResult(
+                audio_features=features,
+                language=language,
+                tokens=tokens,
+                text=text,
+                avg_logprob=avg_logprob,
+                no_speech_prob=no_speech_prob,
+                temperature=self.options.temperature,
+                compression_ratio=compression_ratio(text),
            )
-            self.decoder = TextDecoder(
-                self.dims.n_vocab,
-                self.dims.n_text_ctx,
-                self.dims.n_text_state,
-                self.dims.n_text_head,
-                self.dims.n_text_layer,
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
            )
+        ]
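The strided slice and reshape above undo the earlier `repeat_interleave`: row i * beam_size of the expanded tensors corresponds to batch item i. A toy check of that bookkeeping:

    import paddle

    batch_size, beam_size = 2, 3
    features = paddle.to_tensor([[1.0], [2.0]])            # one row per batch item
    expanded = paddle.repeat_interleave(features, beam_size, axis=0)
    print(expanded.shape)         # [6, 1]: rows 1,1,1,2,2,2
    print(expanded[::beam_size])  # back to one row per batch item
    scores = paddle.arange(6, dtype=paddle.float32)
    print(scores.reshape([batch_size, beam_size]))         # [[0, 1, 2], [3, 4, 5]]
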
-        def embed_audio(self, mel: paddle.Tensor):
-            return self.encoder.forward(mel)
-
-        def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
-            return self.decoder.forward(tokens, audio_features)
-
-        def forward(
-            self, mel: paddle.Tensor, tokens: paddle.Tensor
-        ) -> Dict[str, paddle.Tensor]:
-            return self.decoder(tokens, self.encoder(mel))
-
-        @property
-        def device(self):
-            return paddle.device.get_device()
-
-        @property
-        def is_multilingual(self):
-            return self.dims.n_vocab == 51865
-
-        def install_kv_cache_hooks(self, cache: Optional[dict] = None):
-            """
-            The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
-            tensors calculated for the previous positions. This method returns a dictionary that stores
-            all caches, and the necessary hooks for the key and value projection modules that save the
-            intermediate tensors to be reused during later calculations.
-            Returns
-            -------
-            cache : Dict[nn.Layer, paddle.Tensor]
-                A dictionary object mapping the key/value projection modules to its cache
-            hooks : List[RemovableHandle]
-                List of PyTorch RemovableHandle objects to stop the hooks to be called
-            """
-            cache = {**cache} if cache is not None else {}
-            hooks = []
-
-            def save_to_cache(module, _, output):
-                if (
-                    module not in cache
-                    or output.shape[1] > self.decoder.positional_embedding.shape[0]
-                ):
-                    cache[module] = (
-                        output  # save as-is, for the first token or cross attention
-                    )
-                else:
-                    cache[module] = paddle.concat(
-                        [cache[module], output], axis=1
-                    ).detach()
-                return cache[module]

-            def install_hooks(layer: paddle.nn.Layer):
-                if isinstance(layer, MultiHeadAttention):
-                    hooks.append(layer.key.register_forward_post_hook(save_to_cache))
-                    hooks.append(layer.value.register_forward_post_hook(save_to_cache))
+@paddle.no_grad()
+def decode(
+    model: "Whisper",
+    mel: paddle.Tensor,
+    options: DecodingOptions = DecodingOptions(),
+    resource_path: Optional[str] = None,
+) -> Union[DecodingResult, List[DecodingResult]]:
+    """
+    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+    Parameters
+    ----------
+    model: Whisper
+        the Whisper model instance
+    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
+        A tensor containing the Mel spectrogram(s)
+    options: DecodingOptions
+        A dataclass that contains all necessary options for decoding 30-second segments
+    resource_path: Optional[str]
+        The path to the directory that holds the model's resource files
+    Returns
+    -------
+    result: Union[DecodingResult, List[DecodingResult]]
+        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+    """
+    single = mel.ndim == 2
+    if single:
+        mel = mel.unsqueeze(0)
+
+    result = DecodingTask(model, options, resource_path).run(mel)
+
+    if single:
+        result = result[0]
+
+    return result
+
+
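A hedged usage sketch for the new module-level `decode`; "speech.wav" and `resource_path` are placeholders, and `model` is assumed to be an already-loaded `Whisper` instance:

    audio, _ = soundfile.read("speech.wav", dtype="float32")  # 16 kHz mono
    mel = log_mel_spectrogram(pad_or_trim(audio), resource_path=resource_path)
    result = decode(model, mel, DecodingOptions(language="en"), resource_path)
    print(result.text)
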
+class Whisper(paddle.nn.Layer):
+    """
+    The `Whisper` module combines AudioEncoder and TextDecoder, and exposes the detect_language, transcribe, and decode methods.
+    """
+
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        self.decoder = TextDecoder(
+            self.dims.n_vocab,
+            self.dims.n_text_ctx,
+            self.dims.n_text_state,
+            self.dims.n_text_head,
+            self.dims.n_text_layer,
+        )

-            self.decoder.apply(install_hooks)
-            return cache, hooks
+    def embed_audio(self, mel: paddle.Tensor):
+        return self.encoder.forward(mel)

-        detect_language = detect_language
-        transcribe = transcribe
-        decode = decode
+    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
+        return self.decoder.forward(tokens, audio_features)

-    def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
-        """
-        Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
-        """
-        if paddle.is_tensor(array):
-            if array.shape[axis] > length:
-                array = array.index_select(axis=axis, index=paddle.arange(length))
-
-            if array.shape[axis] < length:
-                pad_widths = [(0, 0)] * array.ndim
-                pad_widths[axis] = (0, length - array.shape[axis])
-                array = paddle.transpose(array, (1, 0))
-                array = paddle.nn.functional.pad(
-                    array,
-                    [pad for sizes in pad_widths[::-1] for pad in sizes],
-                    data_format="NLC",
-                )
-                array = paddle.transpose(array, (1, 0))
-        else:
-            if array.shape[axis] > length:
-                array = array.take(indices=range(length), axis=axis)
-
-            if array.shape[axis] < length:
-                pad_widths = [(0, 0)] * array.ndim
-                pad_widths[axis] = (0, length - array.shape[axis])
-                array = paddle.transpose(array, (1, 0))
-                array = np.pad(array, pad_widths)
-                array = paddle.transpose(array, (1, 0))
+    def forward(
+        self, mel: paddle.Tensor, tokens: paddle.Tensor
+    ) -> Dict[str, paddle.Tensor]:
+        return self.decoder(tokens, self.encoder(mel))

-        return array
+    @property
+    def device(self):
+        return paddle.device.get_device()

-    def hann_window(n_fft: int = N_FFT):
-        """
-        hanning window
-        n_fft: The number of frequency components of the discrete Fourier transform.
-        """
-        return paddle.to_tensor(
-            [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
-            dtype=paddle.float32,
-        )
+    @property
+    def is_multilingual(self):
+        return self.dims.n_vocab == 51865

-    @lru_cache(maxsize=None)
-    def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
+    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
-        load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-        Allows decoupling librosa dependency; saved using:
-            np.savez_compressed(
-                "mel_filters.npz",
-                mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-            )
-        """
-        assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-        with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
-            return paddle.to_tensor(f[f"mel_{n_mels}"])
-
-    @function_requires_deps("soundfile")
-    def log_mel_spectrogram(
-        audio: Union[str, np.ndarray, paddle.Tensor],
-        n_mels: int = N_MELS,
-        resource_path: str = None,
-    ):
-        """
-        Compute the log-Mel spectrogram of
-        Parameters
-        ----------
-        audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
-            The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
-        n_mels: int
-            The number of Mel-frequency filters, only 80 is supported
+        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+        tensors calculated for the previous positions. This method returns a dictionary that stores
+        all caches, and the necessary hooks for the key and value projection modules that save the
+        intermediate tensors to be reused during later calculations.
        Returns
        -------
-        paddle.Tensor, shape = (80, n_frames)
-            A Tensor that contains the Mel spectrogram
+        cache : Dict[nn.Layer, paddle.Tensor]
+            A dictionary mapping each key/value projection module to its cache
+        hooks : List[RemovableHandle]
+            A list of RemovableHandle objects that can be used to remove the installed hooks
        """
-        if not paddle.is_tensor(audio):
-            if isinstance(audio, str):
-                audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
-                audio = audio[:, 0]
-            audio = paddle.to_tensor(audio)
-
-        window = hann_window(N_FFT)
-        stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
-
-        magnitudes = stft[:, :-1].abs() ** 2
+        cache = {**cache} if cache is not None else {}
+        hooks = []

-        filters = mel_filters(resource_path, n_mels)
-        mel_spec = filters @ magnitudes
-        mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
-
-        log_spec = paddle.clip(mel_spec, min=1e-10).log10()
-        log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
-        log_spec = (log_spec + 4.0) / 4.0
-        return log_spec
+        def save_to_cache(module, _, output):
+            if (
+                module not in cache
+                or output.shape[1] > self.decoder.positional_embedding.shape[0]
+            ):
+                cache[module] = (
+                    output  # save as-is, for the first token or cross attention
+                )
+            else:
+                cache[module] = paddle.concat([cache[module], output], axis=1).detach()
+            return cache[module]
+
+        def install_hooks(layer: paddle.nn.Layer):
+            if isinstance(layer, MultiHeadAttention):
+                hooks.append(layer.key.register_forward_post_hook(save_to_cache))
+                hooks.append(layer.value.register_forward_post_hook(save_to_cache))
+
+        self.decoder.apply(install_hooks)
+        return cache, hooks
+
+    detect_language = detect_language
+    transcribe = transcribe
+    decode = decode
+
+
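The hooks returned here are ordinary Paddle forward post-hooks. A hedged sketch of the intended usage, assuming the decoder's forward accepts a kv_cache argument (as in the upstream Whisper design) and that `model`, `tokens`, and `audio_features` already exist:

    cache, hooks = model.install_kv_cache_hooks()
    try:
        # Subsequent decoder calls reuse cached keys/values for earlier positions.
        logits = model.decoder(tokens, audio_features, kv_cache=cache)
    finally:
        for hook in hooks:
            hook.remove()  # detach the forward hooks once decoding is done
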
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if paddle.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(axis=axis, index=paddle.arange(length))
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = paddle.transpose(array, (1, 0))
+            array = paddle.nn.functional.pad(
+                array,
+                [pad for sizes in pad_widths[::-1] for pad in sizes],
+                data_format="NLC",
+            )
+            array = paddle.transpose(array, (1, 0))
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            # np.pad handles NumPy inputs directly; no transposes are needed here
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
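For example, with the NumPy branch above, a 10-second clip at 16 kHz is zero-padded out to the full 30-second window:

    import numpy as np

    short = np.zeros(10 * SAMPLE_RATE, dtype=np.float32)
    padded = pad_or_trim(short)
    print(padded.shape)  # (480000,) == (N_SAMPLES,)
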
+def hann_window(n_fft: int = N_FFT):
+    """
+    Hann window
+    n_fft: The number of frequency components of the discrete Fourier transform.
+    """
+    return paddle.to_tensor(
+        [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
+        dtype=paddle.float32,
+    )
+
+
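The comprehension above is the periodic Hann window, 0.5 - 0.5 * cos(2 * pi * n / n_fft). It equals a symmetric Hann of length n_fft + 1 with the last sample dropped, which gives a quick sanity check:

    import numpy as np

    w = hann_window(N_FFT).numpy()
    assert np.allclose(w, np.hanning(N_FFT + 1)[:-1])
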
+@lru_cache(maxsize=None)
+def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
+    """
+    Load the mel filterbank matrix for projecting the STFT into a Mel spectrogram.
+    This allows decoupling the librosa dependency; the matrix was saved using:
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+        )
+    """
+    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+        return paddle.to_tensor(f[f"mel_{n_mels}"])
+
+
+@function_requires_deps("soundfile")
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, paddle.Tensor],
+    n_mels: int = N_MELS,
+    resource_path: Optional[str] = None,
+):
+    """
+    Compute the log-Mel spectrogram of the input audio.
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
+        The path to an audio file, or a NumPy array or Tensor containing the audio waveform, sampled at 16 kHz
+    n_mels: int
+        The number of Mel-frequency filters, only 80 is supported
    Returns
    -------
+    paddle.Tensor, shape = (80, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    if not paddle.is_tensor(audio):
+        if isinstance(audio, str):
+            audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
+            audio = audio[:, 0]
+        audio = paddle.to_tensor(audio)
+
+    window = hann_window(N_FFT)
+    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
+
+    magnitudes = stft[:, :-1].abs() ** 2
+
+    filters = mel_filters(resource_path, n_mels)
+    mel_spec = filters @ magnitudes
+    mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
+
+    log_spec = paddle.clip(mel_spec, min=1e-10).log10()
+    log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    return log_spec
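Putting the constants together, a hedged shape check (the waveform is a silent placeholder, and `resource_path` must point at a directory containing "assets/mel_filters.npz"):

    import numpy as np

    # N_SAMPLES samples / HOP_LENGTH gives 3001 STFT frames, and the slice
    # magnitudes = stft[:, :-1] trims that to exactly N_FRAMES columns.
    audio = np.zeros(N_SAMPLES, dtype=np.float32)
    mel = log_mel_spectrogram(audio, resource_path=resource_path)
    print(mel.shape)  # [80, 3000] == (N_MELS, N_FRAMES)
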