- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
- import os
- import zlib
- from dataclasses import dataclass, field
- from functools import lru_cache
- from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
- import numpy as np
- import paddle
- from ....utils.deps import function_requires_deps, is_dep_available
- from ...utils.benchmark import (
- benchmark,
- get_inference_operations,
- set_inference_operations,
- )
- from ..common.tokenizer import GPTTokenizer
- if is_dep_available("soundfile"):
- import soundfile
- if is_dep_available("tqdm"):
- import tqdm
- __all__ = [
- "Whisper",
- "Tokenizer",
- ]
- def exact_div(x, y):
- assert x % y == 0
- return x // y
- _MODELS = ["large"]
- SAMPLE_RATE = 16000
- N_FFT = 400
- N_MELS = 80
- HOP_LENGTH = 160
- CHUNK_LENGTH = 30
- N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
- N_FRAMES = exact_div(
- N_SAMPLES, HOP_LENGTH
- ) # 3000: number of frames in a mel spectrogram input
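- # Worked example of the chunking arithmetic above: one 30 s chunk holds
- # 30 * 16000 = 480000 samples, and with a 160-sample hop that gives
- # 480000 / 160 = 3000 mel frames per spectrogram input.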
- @dataclass
- class ModelDimensions:
- n_mels: int
- n_audio_ctx: int
- n_audio_state: int
- n_audio_head: int
- n_audio_layer: int
- n_vocab: int
- n_text_ctx: int
- n_text_state: int
- n_text_head: int
- n_text_layer: int
- LANGUAGES = {
- "en": "english",
- "zh": "chinese",
- "de": "german",
- "es": "spanish",
- "ru": "russian",
- "ko": "korean",
- "fr": "french",
- "ja": "japanese",
- "pt": "portuguese",
- "tr": "turkish",
- "pl": "polish",
- "ca": "catalan",
- "nl": "dutch",
- "ar": "arabic",
- "sv": "swedish",
- "it": "italian",
- "id": "indonesian",
- "hi": "hindi",
- "fi": "finnish",
- "vi": "vietnamese",
- "iw": "hebrew",
- "uk": "ukrainian",
- "el": "greek",
- "ms": "malay",
- "cs": "czech",
- "ro": "romanian",
- "da": "danish",
- "hu": "hungarian",
- "ta": "tamil",
- "no": "norwegian",
- "th": "thai",
- "ur": "urdu",
- "hr": "croatian",
- "bg": "bulgarian",
- "lt": "lithuanian",
- "la": "latin",
- "mi": "maori",
- "ml": "malayalam",
- "cy": "welsh",
- "sk": "slovak",
- "te": "telugu",
- "fa": "persian",
- "lv": "latvian",
- "bn": "bengali",
- "sr": "serbian",
- "az": "azerbaijani",
- "sl": "slovenian",
- "kn": "kannada",
- "et": "estonian",
- "mk": "macedonian",
- "br": "breton",
- "eu": "basque",
- "is": "icelandic",
- "hy": "armenian",
- "ne": "nepali",
- "mn": "mongolian",
- "bs": "bosnian",
- "kk": "kazakh",
- "sq": "albanian",
- "sw": "swahili",
- "gl": "galician",
- "mr": "marathi",
- "pa": "punjabi",
- "si": "sinhala",
- "km": "khmer",
- "sn": "shona",
- "yo": "yoruba",
- "so": "somali",
- "af": "afrikaans",
- "oc": "occitan",
- "ka": "georgian",
- "be": "belarusian",
- "tg": "tajik",
- "sd": "sindhi",
- "gu": "gujarati",
- "am": "amharic",
- "yi": "yiddish",
- "lo": "lao",
- "uz": "uzbek",
- "fo": "faroese",
- "ht": "haitian creole",
- "ps": "pashto",
- "tk": "turkmen",
- "nn": "nynorsk",
- "mt": "maltese",
- "sa": "sanskrit",
- "lb": "luxembourgish",
- "my": "myanmar",
- "bo": "tibetan",
- "tl": "tagalog",
- "mg": "malagasy",
- "as": "assamese",
- "tt": "tatar",
- "haw": "hawaiian",
- "ln": "lingala",
- "ha": "hausa",
- "ba": "bashkir",
- "jw": "javanese",
- "su": "sundanese",
- }
- # language code lookup by name, with a few language aliases
- TO_LANGUAGE_CODE = {
- **{language: code for code, language in LANGUAGES.items()},
- "burmese": "my",
- "valencian": "ca",
- "flemish": "nl",
- "haitian": "ht",
- "letzeburgesch": "lb",
- "pushto": "ps",
- "panjabi": "pa",
- "moldavian": "ro",
- "moldovan": "ro",
- "sinhalese": "si",
- "castilian": "es",
- }
- def compression_ratio(text) -> float:
- return len(text) / len(zlib.compress(text.encode("utf-8")))
- def format_timestamp(
- seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
- ):
- assert seconds >= 0, "non-negative timestamp expected"
- milliseconds = round(seconds * 1000.0)
- hours = milliseconds // 3_600_000
- milliseconds -= hours * 3_600_000
- minutes = milliseconds // 60_000
- milliseconds -= minutes * 60_000
- seconds = milliseconds // 1_000
- milliseconds -= seconds * 1_000
- hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
- return (
- f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
- )
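- # e.g. format_timestamp(3661.5) == "01:01:01.500", while
- # format_timestamp(5.0) == "00:05.000" (hours are omitted unless requested).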
- @dataclass(frozen=True)
- class Tokenizer:
- """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
- tokenizer: "GPTTokenizer"
- language: Optional[str]
- sot_sequence: Tuple[int]
- def encode(self, text, **kwargs):
- return self.tokenizer.encode(text, **kwargs)
- def decode(
- self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
- ):
- if len(token_ids) > 1:
- ids_list = []
- for ids in token_ids:
- if paddle.is_tensor(ids):
- ids = ids.item()
- if ids < len(self.tokenizer):
- ids_list.append(ids)
- token_ids = ids_list
- elif len(token_ids) == 1:
- token_ids = token_ids[0]
- else:
- raise ValueError(f"token_ids {token_ids} is empty and cannot be decoded.")
- return self.tokenizer.decode(token_ids, **kwargs)
- def decode_with_timestamps(self, tokens) -> str:
- """
- Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
- This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
- """
- outputs = [[]]
- for token in tokens:
- if token >= self.timestamp_begin:
- timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
- outputs.append(timestamp)
- outputs.append([])
- else:
- outputs[-1].append(token)
- outputs = [
- s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
- ]
- return "".join(outputs)
- @property
- @lru_cache()
- def eot(self) -> int:
- return self.tokenizer.eos_token_id
- @property
- @lru_cache()
- def sot(self) -> int:
- return self._get_single_token_id("<|startoftranscript|>")
- @property
- @lru_cache()
- def sot_lm(self) -> int:
- return self._get_single_token_id("<|startoflm|>")
- @property
- @lru_cache()
- def sot_prev(self) -> int:
- return self._get_single_token_id("<|startofprev|>")
- @property
- @lru_cache()
- def no_speech(self) -> int:
- return self._get_single_token_id("<|nospeech|>")
- @property
- @lru_cache()
- def no_timestamps(self) -> int:
- return self._get_single_token_id("<|notimestamps|>")
- @property
- @lru_cache()
- def timestamp_begin(self) -> int:
- return self.tokenizer.all_special_ids[-1] + 1
- @property
- @lru_cache()
- def language_token(self) -> int:
- """Returns the token id corresponding to the value of the `language` field"""
- if self.language is None:
- raise ValueError("This tokenizer does not have language token configured")
- additional_tokens = dict(
- zip(
- self.tokenizer.additional_special_tokens,
- self.tokenizer.additional_special_tokens_ids,
- )
- )
- candidate = f"<|{self.language}|>"
- if candidate in additional_tokens:
- return additional_tokens[candidate]
- raise KeyError(f"Language {self.language} not found in tokenizer.")
- @property
- @lru_cache()
- def all_language_tokens(self) -> Tuple[int]:
- result = []
- for token, token_id in zip(
- self.tokenizer.additional_special_tokens,
- self.tokenizer.additional_special_tokens_ids,
- ):
- if token.strip("<|>") in LANGUAGES:
- result.append(token_id)
- return tuple(result)
- @property
- @lru_cache()
- def all_language_codes(self) -> Tuple[str]:
- return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
- @property
- @lru_cache()
- def sot_sequence_including_notimestamps(self) -> Tuple[int]:
- return tuple(list(self.sot_sequence) + [self.no_timestamps])
- @property
- @lru_cache()
- def non_speech_tokens(self) -> Tuple[int]:
- """
- Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
- annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- - ♪♪♪
- - ( SPEAKING FOREIGN LANGUAGE )
- - [DAVID] Hey there,
- keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
- """
- symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
- symbols += (
- "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
- )
- # symbols that may be a single token or multiple tokens depending on the tokenizer.
- # In case they're multiple tokens, suppress the first token, which is safe because:
- # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
- # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
- miscellaneous = set("♩♪♫♬♭♮♯")
- assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
- # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
- result = {
- self.tokenizer.encode(" -").input_ids[0],
- self.tokenizer.encode(" '").input_ids[0],
- }
- for symbol in symbols + list(miscellaneous):
- for tokens in [
- self.tokenizer.encode(symbol).input_ids,
- self.tokenizer.encode(" " + symbol).input_ids,
- ]:
- if len(tokens) == 1 or symbol in miscellaneous:
- result.add(tokens[0])
- return tuple(sorted(result))
- def _get_single_token_id(self, text) -> int:
- tokens = self.tokenizer.encode(text).input_ids
- assert len(tokens) == 1, f"{text} is not encoded as a single token"
- return tokens[0]
- @lru_cache(maxsize=None)
- def build_tokenizer(resource_path: str, name: str = "gpt2"):
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- path = os.path.join(resource_path, "assets", name)
- tokenizer = GPTTokenizer.from_pretrained(path)
- specials = [
- "<|startoftranscript|>",
- *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
- "<|translate|>",
- "<|transcribe|>",
- "<|startoflm|>",
- "<|startofprev|>",
- "<|nospeech|>",
- "<|notimestamps|>",
- ]
- tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
- return tokenizer
- @lru_cache(maxsize=None)
- def get_tokenizer(
- multilingual: bool,
- resource_path: str,
- *,
- task: Optional[str] = None, # Literal["transcribe", "translate", None]
- language: Optional[str] = None,
- ) -> Tokenizer:
- if language is not None:
- language = language.lower()
- if language not in LANGUAGES:
- if language in TO_LANGUAGE_CODE:
- language = TO_LANGUAGE_CODE[language]
- else:
- raise ValueError(f"Unsupported language: {language}")
- if multilingual:
- tokenizer_name = "multilingual"
- task = task or "transcribe"
- language = language or "en"
- else:
- tokenizer_name = "gpt2"
- task = None
- language = None
- tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
- all_special_ids: List[int] = tokenizer.all_special_ids
- sot: int = all_special_ids[1]
- translate: int = all_special_ids[-6]
- transcribe: int = all_special_ids[-5]
- langs = tuple(LANGUAGES.keys())
- sot_sequence = [sot]
- if language is not None:
- sot_sequence.append(sot + 1 + langs.index(language))
- if task is not None:
- sot_sequence.append(transcribe if task == "transcribe" else translate)
- return Tokenizer(
- tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
- )
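- # e.g. get_tokenizer(True, resource_path, task="transcribe", language="en")
- # builds sot_sequence = (sot, sot + 1, transcribe), since "en" is the first
- # entry in LANGUAGES.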
- class MultiHeadAttention(paddle.nn.Layer):
- def __init__(self, n_state: int, n_head: int):
- super().__init__()
- self.n_head = n_head
- self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
- self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
- self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
- self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
- def forward(
- self,
- x: paddle.Tensor,
- xa: Optional[paddle.Tensor] = None,
- mask: Optional[paddle.Tensor] = None,
- kv_cache: Optional[dict] = None,
- ):
- q = self.query(x)
- if kv_cache is None or xa is None or self.key not in kv_cache:
- # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
- # otherwise, perform key/value projections for self- or cross-attention as usual.
- k = self.key(x if xa is None else xa)
- v = self.value(x if xa is None else xa)
- else:
- # for cross-attention, calculate keys and values once and reuse in subsequent calls.
- k = kv_cache[self.key]
- v = kv_cache[self.value]
- wv = self.qkv_attention(q, k, v, mask)
- return self.out(wv)
- def qkv_attention(
- self,
- q: paddle.Tensor,
- k: paddle.Tensor,
- v: paddle.Tensor,
- mask: Optional[paddle.Tensor] = None,
- ):
- n_batch, n_ctx, n_state = q.shape
- scale = (n_state // self.n_head) ** -0.25
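- # scale q and k each by d_head ** -0.25 so their product is effectively
- # divided by sqrt(d_head), i.e. standard scaled dot-product attention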
- q = (
- paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
- * scale
- )
- k = (
- paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
- * scale
- )
- v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
- qk = q @ k
- if mask is not None:
- qk = qk + mask[:n_ctx, :n_ctx]
- w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
- return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
- class ResidualAttentionBlock(paddle.nn.Layer):
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
- super().__init__()
- self.attn = MultiHeadAttention(n_state, n_head)
- self.attn_ln = paddle.nn.LayerNorm(n_state)
- self.cross_attn = (
- MultiHeadAttention(n_state, n_head) if cross_attention else None
- )
- self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None
- n_mlp = n_state * 4
- self.mlp = paddle.nn.Sequential(
- paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
- paddle.nn.GELU(),
- paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
- )
- self.mlp_ln = paddle.nn.LayerNorm(n_state)
- def forward(
- self,
- x: paddle.Tensor,
- xa: Optional[paddle.Tensor] = None,
- mask: Optional[paddle.Tensor] = None,
- kv_cache: Optional[dict] = None,
- ):
- x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
- if self.cross_attn:
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
- x = x + self.mlp(self.mlp_ln(x))
- return x
- def sinusoids(length, channels, max_timescale=10000):
- """Returns sinusoids for positional embedding"""
- assert channels % 2 == 0
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = paddle.exp(
- -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
- )
- scaled_time = (
- paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
- * inv_timescales[np.newaxis, :]
- )
- return paddle.to_tensor(
- paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
- )
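- # returns a (length, channels) tensor: first half sin, second half cos,
- # over geometrically spaced timescales from 1 up to max_timescale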
- class AudioEncoder(paddle.nn.Layer):
- def __init__(
- self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
- super().__init__()
- self.conv1 = paddle.nn.Conv1D(
- n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
- )
- self.conv2 = paddle.nn.Conv1D(
- n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
- )
- self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
- self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
- [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
- )
- self.ln_post = paddle.nn.LayerNorm(n_state)
- def forward(self, x: paddle.Tensor):
- """
- x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
- the mel spectrogram of the audio
- """
- x = paddle.nn.functional.gelu(self.conv1(x))
- x = paddle.nn.functional.gelu(self.conv2(x))
- x = paddle.transpose(x, (0, 2, 1))
- assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
- x = x + self.positional_embedding
- for block in self.blocks:
- x = block(x)
- x = self.ln_post(x)
- return x
- class TextDecoder(paddle.nn.Layer):
- def __init__(
- self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
- super().__init__()
- self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
- self.positional_embedding = paddle.create_parameter(
- shape=[n_ctx, n_state], dtype="float32"
- )
- self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
- [
- ResidualAttentionBlock(n_state, n_head, cross_attention=True)
- for _ in range(n_layer)
- ]
- )
- self.ln = paddle.nn.LayerNorm(n_state)
- mask = paddle.full(shape=[n_ctx, n_ctx], fill_value=-np.inf, dtype="float32")
- mask = paddle.triu(mask, diagonal=1)
- self.register_buffer("mask", mask, persistable=False)
- def forward(
- self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
- ):
- """
- x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
- the text tokens
- xa : paddle.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
- the encoded audio features to be attended on
- """
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
- x = (
- self.token_embedding(x)
- + self.positional_embedding[offset : offset + x.shape[-1]]
- )
- x = x.to(xa.dtype)
- for block in self.blocks:
- x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
- x = self.ln(x)
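- # tied output projection: reuse the token embedding matrix as the
- # vocabulary classifier instead of a separate linear layer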
- logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
- return logits
- @dataclass(frozen=True)
- class DecodingOptions:
- task: str = (
- "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
- )
- language: Optional[str] = (
- None # language that the audio is in; uses detected language if None
- )
- # sampling-related options
- temperature: float = 0.0
- sample_len: Optional[int] = None # maximum number of tokens to sample
- best_of: Optional[int] = (
- None # number of independent samples to collect, when t > 0
- )
- beam_size: Optional[int] = None # number of beams in beam search, when t == 0
- patience: Optional[float] = (
- None # patience in beam search (https://arxiv.org/abs/2204.05424)
- )
- # options for ranking generations (either beams or best-of-N samples)
- length_penalty: Optional[float] = (
- None # "alpha" in Google NMT, None defaults to length norm
- )
- # prompt, prefix, and token suppression
- prompt: Optional[Union[str, List[int]]] = (
- None # text or tokens for the previous context
- )
- prefix: Optional[Union[str, List[int]]] = (
- None # text or tokens to prefix the current context
- )
- suppress_blank: bool = True # this will suppress blank outputs
- # list of tokens ids (or comma-separated token ids) to suppress
- # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
- suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
- # timestamp sampling options
- without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
- max_initial_timestamp: Optional[float] = (
- 1.0 # the initial timestamp cannot be later than this
- )
- # implementation details
- fp16: bool = False # use fp16 for most of the calculation
- @dataclass(frozen=True)
- class DecodingResult:
- audio_features: paddle.Tensor
- language: str
- language_probs: Optional[Dict[str, float]] = None
- tokens: List[int] = field(default_factory=list)
- text: str = ""
- avg_logprob: float = np.nan
- no_speech_prob: float = np.nan
- temperature: float = np.nan
- compression_ratio: float = np.nan
- class Inference:
- def logits(
- self, tokens: paddle.Tensor, audio_features: paddle.Tensor
- ) -> paddle.Tensor:
- """Perform a forward pass on the decoder and return per-token logits"""
- raise NotImplementedError
- def rearrange_kv_cache(self, source_indices) -> None:
- """Update the key-value cache according to the updated beams"""
- raise NotImplementedError
- def cleanup_caching(self) -> None:
- """Clean up any resources or hooks after decoding is finished"""
- class WhisperInference(Inference):
- def __init__(self, model: "Whisper", initial_token_length: int):
- self.model: "Whisper" = model
- self.initial_token_length = initial_token_length
- self.kv_cache = {}
- self.hooks = []
- def logits(
- self, tokens: paddle.Tensor, audio_features: paddle.Tensor
- ) -> paddle.Tensor:
- if not self.kv_cache:
- self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
- if tokens.shape[-1] > self.initial_token_length:
- # only need to use the last token except in the first forward pass
- tokens = tokens[:, -1:]
- return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
- def cleanup_caching(self):
- for hook in self.hooks:
- hook.remove()
- self.kv_cache = {}
- self.hooks = []
- def rearrange_kv_cache(self, source_indices):
- for module, tensor in self.kv_cache.items():
- # update the key/value cache to contain the selected sequences
- self.kv_cache[module] = tensor[source_indices].detach()
- @paddle.no_grad()
- def detect_language(
- model: "Whisper",
- mel: paddle.Tensor,
- resource_path: str,
- tokenizer: Tokenizer = None,
- ) -> Tuple[paddle.Tensor, List[dict]]:
- """
- Detect the spoken language in the audio and return the ids of the most probable
- language tokens, along with the probability distribution over all language tokens.
- This is performed outside the main decode loop in order to not interfere with kv-caching.
- Returns
- -------
- language_tokens : Tensor, shape = (batch_size,)
- ids of the most probable language tokens, which appear after the startoftranscript token.
- language_probs : List[Dict[str, float]], length = batch_size
- list of dictionaries containing the probability distribution over all languages.
- """
- if tokenizer is None:
- tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
- if (
- tokenizer.language is None
- or tokenizer.language_token not in tokenizer.sot_sequence
- ):
- raise ValueError(
- "This model doesn't have language tokens so it can't perform lang id"
- )
- single = mel.ndim == 2
- if single:
- mel = mel.unsqueeze(0)
- # skip encoder forward pass if already-encoded audio features were given
- if tuple(mel.shape[-2:]) != (model.dims.n_audio_ctx, model.dims.n_audio_state):
- mel = model.encoder(mel)
- # forward pass using a single token, startoftranscript
- batch_size = mel.shape[0]
- x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
- logits = model.logits(x, mel)[:, 0]
- # collect detected languages; suppress all non-language tokens
- mask = paddle.ones([logits.shape[-1]], dtype=bool)
- mask[list(tokenizer.all_language_tokens)] = False
- logits[:, mask] = -np.inf
- language_tokens = paddle.argmax(logits, axis=-1)
- language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
- language_probs = [
- {
- c: language_token_probs[i, j].tolist()
- for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
- }
- for i in range(batch_size)
- ]
- if single:
- language_tokens = language_tokens[0]
- language_probs = language_probs[0]
- return language_tokens, language_probs
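- # A minimal usage sketch (assuming `model` and a 2-D log-mel `mel` exist):
- # lang_token, probs = detect_language(model, mel, resource_path)
- # best = max(probs, key=probs.get)  # e.g. "en"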
- @function_requires_deps("tqdm")
- def transcribe(
- model: "Whisper",
- mel: paddle.Tensor,
- resource_path: str,
- *,
- verbose: Optional[bool] = None,
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
- compression_ratio_threshold: Optional[float] = 2.4,
- logprob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- **decode_options,
- ):
- """
- Transcribe an audio file using Whisper
- Parameters
- ----------
- model: Whisper
- The Whisper model instance
- mel: paddle.Tensor
- The audio feature
- verbose: bool
- Whether to display the text being decoded to the console. If True, displays all details;
- if False, displays minimal details; if None, does not display anything
- temperature: Union[float, Tuple[float, ...]]
- Temperature for sampling. It can be a tuple of temperatures, which will be successively used
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
- compression_ratio_threshold: float
- If the gzip compression ratio is above this value, treat as failed
- logprob_threshold: float
- If the average log probability over sampled tokens is below this value, treat as failed
- no_speech_threshold: float
- If the no_speech probability is higher than this value AND the average log probability
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
- condition_on_previous_text: bool
- if True, the previous output of the model is provided as a prompt for the next window;
- disabling may make the text inconsistent across windows, but the model becomes less prone to
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
- decode_options: dict
- Keyword arguments to construct `DecodingOptions` instances
- Returns
- -------
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
- """
- dtype = np.float32 # paddle only supports float32
- if dtype == np.float32:
- decode_options["fp16"] = False
- if (
- decode_options.get("language") == "None"
- or decode_options.get("language", None) is None
- ):
- if not model.is_multilingual:
- decode_options["language"] = "en"
- else:
- if verbose:
- print(
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
- )
- segment = pad_or_trim(mel, N_FRAMES)
- _, probs = model.detect_language(segment, resource_path)
- decode_options["language"] = max(probs, key=probs.get)
- if verbose is not None:
- print(
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
- )
- language = decode_options["language"]
- task = decode_options.get("task", "transcribe")
- tokenizer = get_tokenizer(
- model.is_multilingual,
- resource_path=resource_path,
- language=language,
- task=task,
- )
- def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
- temperatures = (
- [temperature] if isinstance(temperature, (int, float)) else temperature
- )
- decode_result = None
- for t in temperatures:
- kwargs = {**decode_options}
- if t > 0:
- # disable beam_size and patience when t > 0
- kwargs.pop("beam_size", None)
- kwargs.pop("patience", None)
- else:
- # disable best_of when t == 0
- kwargs.pop("best_of", None)
- options = DecodingOptions(**kwargs, temperature=t)
- decode_result = model.decode(segment, options, resource_path)
- needs_fallback = False
- if (
- compression_ratio_threshold is not None
- and decode_result.compression_ratio > compression_ratio_threshold
- ):
- needs_fallback = True # too repetitive
- if (
- logprob_threshold is not None
- and decode_result.avg_logprob < logprob_threshold
- ):
- needs_fallback = True # average log probability is too low
- if not needs_fallback:
- break
- return decode_result
- seek = 0
- input_stride = exact_div(
- N_FRAMES, model.dims.n_audio_ctx
- ) # mel frames per output token: 2
- time_precision = (
- input_stride * HOP_LENGTH / SAMPLE_RATE
- ) # time per output token: 0.02 (seconds)
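- # e.g. with standard Whisper dims (n_audio_ctx = 1500): input_stride =
- # 3000 / 1500 = 2 mel frames per token, so time_precision = 2 * 160 / 16000 = 0.02 s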
- all_tokens = []
- all_segments = []
- prompt_reset_since = 0
- initial_prompt = decode_options.pop("initial_prompt", None)
- if initial_prompt and initial_prompt != "None":
- initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
- all_tokens.extend(initial_prompt)
- else:
- initial_prompt = []
- def add_segment(
- *,
- start: float,
- end: float,
- text_tokens: paddle.Tensor,
- result: DecodingResult,
- ):
- text = tokenizer.decode(
- [token for token in text_tokens if token < tokenizer.eot]
- )
- if len(text.strip()) == 0: # skip empty text output
- return
- all_segments.append(
- {
- "id": len(all_segments),
- "seek": seek,
- "start": start,
- "end": end,
- "text": text,
- "tokens": result.tokens,
- "temperature": result.temperature,
- "avg_logprob": result.avg_logprob,
- "compression_ratio": result.compression_ratio,
- "no_speech_prob": result.no_speech_prob,
- }
- )
- if verbose:
- print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
- # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
- num_frames = mel.shape[-1]
- previous_seek_value = seek
- with tqdm.tqdm(
- total=num_frames, unit="frames", disable=verbose is not False
- ) as pbar:
- while seek < num_frames:
- timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
- segment = pad_or_trim(mel[:, seek:], N_FRAMES)
- segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
- result: DecodingResult = decode_with_fallback(segment)
- tokens = paddle.to_tensor(result.tokens)
- if no_speech_threshold is not None:
- # no voice activity check
- should_skip = result.no_speech_prob > no_speech_threshold
- if (
- logprob_threshold is not None
- and result.avg_logprob > logprob_threshold
- ):
- # don't skip if the logprob is high enough, despite the no_speech_prob
- should_skip = False
- if should_skip:
- seek += segment.shape[
- -1
- ] # fast-forward to the next segment boundary
- continue
- timestamp_tokens: paddle.Tensor = tokens.greater_equal(
- paddle.to_tensor(tokenizer.timestamp_begin)
- )
- consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
- if (
- len(consecutive) > 0
- ): # if the output contains two consecutive timestamp tokens
- consecutive = paddle.add(consecutive, paddle.to_tensor(1))
- last_slice = 0
- for current_slice in consecutive:
- sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_position = (
- sliced_tokens[0].item() - tokenizer.timestamp_begin
- )
- end_timestamp_position = (
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
- )
- add_segment(
- start=timestamp_offset
- + start_timestamp_position * time_precision,
- end=timestamp_offset + end_timestamp_position * time_precision,
- text_tokens=sliced_tokens[1:-1],
- result=result,
- )
- last_slice = current_slice
- last_timestamp_position = (
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
- )
- seek += last_timestamp_position * input_stride
- all_tokens.extend(tokens[: last_slice + 1].tolist())
- else:
- duration = segment_duration
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
- if (
- len(timestamps) > 0
- and timestamps[-1].item() != tokenizer.timestamp_begin
- ):
- # no consecutive timestamps but it has a timestamp; use the last one.
- # single timestamp at the end means no speech after the last timestamp.
- last_timestamp_position = (
- timestamps[-1].item() - tokenizer.timestamp_begin
- )
- duration = last_timestamp_position * time_precision
- add_segment(
- start=timestamp_offset,
- end=timestamp_offset + duration,
- text_tokens=tokens,
- result=result,
- )
- seek += segment.shape[-1]
- all_tokens.extend(tokens.tolist())
- if not condition_on_previous_text or result.temperature > 0.5:
- # do not feed the prompt tokens if a high temperature was used
- prompt_reset_since = len(all_tokens)
- # update progress bar
- pbar.update(min(num_frames, seek) - previous_seek_value)
- previous_seek_value = seek
- return dict(
- text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
- segments=all_segments,
- language=language,
- )
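- # A minimal usage sketch, with `model` a loaded Whisper instance and `mel`
- # a log-mel spectrogram tensor (illustrative names, not defined in this file):
- # result = transcribe(model, mel, resource_path, verbose=False)
- # print(result["text"], result["language"], len(result["segments"]))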
- class SequenceRanker:
- def rank(
- self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
- ) -> List[int]:
- """
- Given a list of groups of samples and their cumulative log probabilities,
- return the indices of the samples in each group to select as the final result
- """
- raise NotImplementedError
- class MaximumLikelihoodRanker(SequenceRanker):
- """
- Select the sample with the highest log probabilities, penalized using either
- a simple length normalization or Google NMT paper's length penalty
- """
- def __init__(self, length_penalty: Optional[float]):
- self.length_penalty = length_penalty
- def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
- def scores(logprobs, lengths):
- result = []
- for logprob, length in zip(logprobs, lengths):
- if self.length_penalty is None or self.length_penalty == "None":
- penalty = length
- else:
- # from the Google NMT paper
- penalty = ((5 + length) / 6) ** self.length_penalty
- result.append(logprob / penalty)
- return result
- # get the sequence with the highest score
- lengths = [[len(t) for t in s] for s in tokens]
- return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
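- # e.g. with length_penalty = 0.6 and a 30-token sequence, the Google NMT
- # penalty is ((5 + 30) / 6) ** 0.6 ≈ 2.88, versus 30 under plain length norm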
- class TokenDecoder:
- def reset(self):
- """Initialize any stateful variables for decoding a new sequence"""
- def update(
- self,
- tokens: paddle.Tensor,
- logits: paddle.Tensor,
- sum_logprobs: paddle.Tensor,
- ) -> Tuple[paddle.Tensor, bool]:
- """Specify how to select the next token, based on the current trace and logits
- Parameters
- ----------
- tokens : Tensor, shape = (n_batch, current_sequence_length)
- all tokens in the context so far, including the prefix and sot_sequence tokens
- logits : Tensor, shape = (n_batch, vocab_size)
- per-token logits of the probability distribution at the current step
- sum_logprobs : Tensor, shape = (n_batch)
- cumulative log probabilities for each sequence
- Returns
- -------
- tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
- the tokens, appended with the selected next token
- completed : bool
- True if all sequences have reached the end of text
- """
- raise NotImplementedError
- def finalize(
- self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
- ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
- """Finalize search and return the final candidate sequences
- Parameters
- ----------
- tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
- all tokens in the context so far, including the prefix and sot_sequence
- sum_logprobs : Tensor, shape = (batch_size, beam_size)
- cumulative log probabilities for each sequence
- Returns
- -------
- tokens : Sequence[Sequence[Tensor]], length = batch_size
- sequence of Tensors containing candidate token sequences, for each audio input
- sum_logprobs : List[List[float]], length = batch_size
- sequence of cumulative log probabilities corresponding to the above
- """
- raise NotImplementedError
- class GreedyDecoder(TokenDecoder):
- def __init__(self, temperature: float, eot: int):
- self.temperature = temperature
- self.eot = eot
- def update(
- self,
- tokens: paddle.Tensor,
- logits: paddle.Tensor,
- sum_logprobs: paddle.Tensor,
- ) -> Tuple[paddle.Tensor, bool]:
- temperature = self.temperature
- if temperature == 0:
- next_tokens = paddle.argmax(logits, axis=-1)
- else:
- next_tokens = paddle.distribution.Categorical(
- logits=logits / temperature
- ).sample([1])
- next_tokens = paddle.reshape(
- next_tokens,
- [
- next_tokens.shape[0] * next_tokens.shape[1],
- ],
- )
- logprobs = paddle.nn.functional.log_softmax(
- logits, axis=-1, dtype=paddle.float32
- )
- current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
- sum_logprobs += current_logprobs * paddle.to_tensor(
- (tokens[:, -1] != self.eot), dtype=paddle.float32
- )
- next_tokens[tokens[:, -1] == self.eot] = self.eot
- tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
- completed = paddle.all((tokens[:, -1] == self.eot))
- return tokens, completed
- def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
- # make sure each sequence has at least one EOT token at the end
- tokens = paddle.nn.functional.pad(
- tokens, (0, 1), value=self.eot, data_format="NCL"
- )
- return tokens, sum_logprobs.tolist()
- class BeamSearchDecoder(TokenDecoder):
- def __init__(
- self,
- beam_size: int,
- eot: int,
- inference: Inference,
- patience: Optional[float] = None,
- ):
- self.beam_size = beam_size
- self.eot = eot
- self.inference = inference
- if patience is None or patience == "None":
- self.patience = 1.0
- else:
- self.patience = patience
- self.max_candidates: int = round(beam_size * self.patience)
- self.finished_sequences = None
- assert (
- self.max_candidates > 0
- ), f"Invalid beam size ({beam_size}) or patience ({patience})"
- def reset(self):
- self.finished_sequences = None
- def update(
- self,
- tokens: paddle.Tensor,
- logits: paddle.Tensor,
- sum_logprobs: paddle.Tensor,
- ) -> Tuple[paddle.Tensor, bool]:
- if tokens.shape[0] % self.beam_size != 0:
- raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
- batch_size = tokens.shape[0] // self.beam_size
- if self.finished_sequences is None: # for the first update
- self.finished_sequences = [{} for _ in range(batch_size)]
- logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
- next_tokens, source_indices, finished_sequences = [], [], []
- for i in range(batch_size):
- scores, sources, finished = {}, {}, {}
- # STEP 1: calculate the cumulative log probabilities for possible candidates
- for j in range(self.beam_size):
- idx = i * self.beam_size + j
- prefix = tokens[idx].tolist()
- logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
- for logprob, token in zip(logprob, token):
- new_logprob = (sum_logprobs[idx] + logprob).item()
- sequence = tuple(prefix + [token.item()])
- scores[sequence] = new_logprob
- sources[sequence] = idx
- # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
- saved = 0
- for sequence in sorted(scores, key=scores.get, reverse=True):
- if sequence[-1] == self.eot:
- finished[sequence] = scores[sequence]
- else:
- sum_logprobs[len(next_tokens)] = scores[sequence]
- next_tokens.append(sequence)
- source_indices.append(sources[sequence])
- saved += 1
- if saved == self.beam_size:
- break
- finished_sequences.append(finished)
- tokens = paddle.to_tensor(next_tokens)
- self.inference.rearrange_kv_cache(source_indices)
- # add newly finished sequences to self.finished_sequences
- assert len(self.finished_sequences) == len(finished_sequences)
- for previously_finished, newly_finished in zip(
- self.finished_sequences, finished_sequences
- ):
- for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
- if len(previously_finished) >= self.max_candidates:
- break # the candidate list is full
- previously_finished[seq] = newly_finished[seq]
- # mark as completed once every audio has accumulated enough finished sequences
- completed = all(
- len(sequences) >= self.max_candidates
- for sequences in self.finished_sequences
- )
- return tokens, completed
- def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
- # collect all finished sequences, including patience, and add unfinished ones if not enough
- sum_logprobs = sum_logprobs.cpu()
- for i, sequences in enumerate(self.finished_sequences):
- if (
- len(sequences) < self.beam_size
- ): # when not enough sequences are finished
- for j in list(np.argsort(sum_logprobs[i]))[::-1]:
- sequence = preceding_tokens[i, j].tolist() + [self.eot]
- sequences[tuple(sequence)] = sum_logprobs[i][j].item()
- if len(sequences) >= self.beam_size:
- break
- tokens: List[List[paddle.Tensor]] = [
- [paddle.to_tensor(seq) for seq in sequences.keys()]
- for sequences in self.finished_sequences
- ]
- sum_logprobs: List[List[float]] = [
- list(sequences.values()) for sequences in self.finished_sequences
- ]
- return tokens, sum_logprobs
- class LogitFilter:
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
- """Apply any filtering or masking to logits in-place
- Parameters
- ----------
- logits : Tensor, shape = (n_batch, vocab_size)
- per-token logits of the probability distribution at the current step
- tokens : Tensor, shape = (n_batch, current_sequence_length)
- all tokens in the context so far, including the prefix and sot_sequence tokens
- """
- raise NotImplementedError
- class SuppressBlank(LogitFilter):
- def __init__(self, tokenizer: Tokenizer, sample_begin: int):
- self.tokenizer = tokenizer
- self.sample_begin = sample_begin
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
- if tokens.shape[1] == self.sample_begin:
- logits.contiguous()
- logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
- -np.inf
- )
- class SuppressTokens(LogitFilter):
- def __init__(self, suppress_tokens: Sequence[int]):
- self.suppress_tokens = list(suppress_tokens)
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
- logits.contiguous()
- logits[:, self.suppress_tokens] = -np.inf
- class ApplyTimestampRules(LogitFilter):
- def __init__(
- self,
- tokenizer: Tokenizer,
- sample_begin: int,
- max_initial_timestamp_index: Optional[int],
- ):
- self.tokenizer = tokenizer
- self.sample_begin = sample_begin
- self.max_initial_timestamp_index = max_initial_timestamp_index
- def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
- # suppress <|notimestamps|> which is handled by without_timestamps
- if self.tokenizer.no_timestamps is not None:
- logits.contiguous()
- logits[:, self.tokenizer.no_timestamps] = -np.inf
- # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
- for k in range(tokens.shape[0]):
- seq = [t for t in tokens[k, self.sample_begin :].tolist()]
- last_was_timestamp = (
- len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
- )
- penultimate_was_timestamp = (
- len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
- )
- if last_was_timestamp:
- logits.contiguous()
- if penultimate_was_timestamp: # has to be non-timestamp
- logits[k, self.tokenizer.timestamp_begin :] = -np.inf
- else: # cannot be normal text tokens
- logits[k, : self.tokenizer.eot] = -np.inf
- # apply the `max_initial_timestamp` option
- if (
- tokens.shape[1] == self.sample_begin
- and self.max_initial_timestamp_index is not None
- ):
- last_allowed = (
- self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
- )
- logits.contiguous()
- logits[:, last_allowed + 1 :] = -np.inf
- # if sum of probability over timestamps is above any other token, sample timestamp
- logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
- for k in range(tokens.shape[0]):
- # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
- # To bypass this issue in CI, we have decomposed the operation into separate steps.
- # This introduces a difference of about 2e-6 in precision.
- # TODO: revert this after logsumexp been fixed.
- timestamp_logprob = paddle.exp(
- logprobs[k, self.tokenizer.timestamp_begin :]
- )
- timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
- timestamp_logprob = paddle.log(timestamp_logprob)
- max_text_token_logprob = paddle.max(
- logprobs[k, : self.tokenizer.timestamp_begin]
- )
- if timestamp_logprob > max_text_token_logprob:
- logits.contiguous()
- logits[k, : self.tokenizer.timestamp_begin] = -np.inf
- class DecodingTask:
- inference: Inference
- sequence_ranker: SequenceRanker
- decoder: TokenDecoder
- logit_filters: List[LogitFilter]
- def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
- self.model = model
- language = options.language or "en"
- tokenizer = get_tokenizer(
- model.is_multilingual,
- resource_path=resource_path,
- language=language,
- task=options.task,
- )
- self.tokenizer: Tokenizer = tokenizer
- self.options: DecodingOptions = self._verify_options(options)
- self.resource_path: str = resource_path
- self.beam_size: int = options.beam_size or options.best_of or 1
- self.n_ctx: int = model.dims.n_text_ctx
- self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
- self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
- if self.options.without_timestamps:
- self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
- self.initial_tokens: Tuple[int] = self._get_initial_tokens()
- self.sample_begin: int = len(self.initial_tokens)
- self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
- # inference: implements the forward pass through the decoder, including kv caching
- self.inference = WhisperInference(model, len(self.initial_tokens))
- # sequence ranker: implements how to rank a group of sampled sequences
- self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
- # decoder: implements how to select the next tokens, given the autoregressive distribution
- if options.beam_size is not None:
- self.decoder = BeamSearchDecoder(
- options.beam_size, tokenizer.eot, self.inference, options.patience
- )
- else:
- self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
- # logit filters: applies various rules to suppress or penalize certain tokens
- self.logit_filters = []
- if self.options.suppress_blank:
- self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
- if self.options.suppress_tokens:
- self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
- if not options.without_timestamps:
- precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
- max_initial_timestamp_index = None
- if options.max_initial_timestamp:
- max_initial_timestamp_index = round(
- self.options.max_initial_timestamp / precision
- )
- self.logit_filters.append(
- ApplyTimestampRules(
- tokenizer, self.sample_begin, max_initial_timestamp_index
- )
- )
- def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
- if options.beam_size is not None and options.best_of is not None:
- raise ValueError("beam_size and best_of can't be given together")
- if options.temperature == 0:
- if options.best_of is not None:
- raise ValueError("best_of with greedy sampling (T=0) is not compatible")
- if options.patience is not None and options.beam_size is None:
- raise ValueError("patience requires beam_size to be given")
- if options.length_penalty is not None and options.length_penalty != "None":
- if not (0 <= options.length_penalty <= 1):
- raise ValueError(
- "length_penalty (alpha) should be a value between 0 and 1"
- )
- return options
- def _get_initial_tokens(self) -> Tuple[int]:
- tokens = list(self.sot_sequence)
- prefix = self.options.prefix
- prompt = self.options.prompt
- if prefix:
- prefix_tokens = (
- self.tokenizer.encode(" " + prefix.strip().input_ids)
- if isinstance(prefix, str)
- else prefix
- )
- if self.sample_len is not None:
- max_prefix_len = self.n_ctx // 2 - self.sample_len
- prefix_tokens = prefix_tokens[-max_prefix_len:]
- tokens = tokens + prefix_tokens
- if prompt:
- prompt_tokens = (
- self.tokenizer.encode(" " + prompt.strip().input_ids)
- if isinstance(prompt, str)
- else prompt
- )
- tokens = (
- [self.tokenizer.sot_prev]
- + prompt_tokens[-(self.n_ctx // 2 - 1) :]
- + tokens
- )
- return tuple(tokens)
- def _get_suppress_tokens(self) -> Tuple[int]:
- suppress_tokens = self.options.suppress_tokens
- if isinstance(suppress_tokens, str):
- suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
- if -1 in suppress_tokens:
- suppress_tokens = [t for t in suppress_tokens if t >= 0]
- suppress_tokens.extend(self.tokenizer.non_speech_tokens)
- elif suppress_tokens is None or len(suppress_tokens) == 0:
- suppress_tokens = [] # interpret empty string as an empty list
- else:
- assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
- suppress_tokens.extend(
- [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
- )
- if self.tokenizer.no_speech is not None:
- # no-speech probability is collected separately
- suppress_tokens.append(self.tokenizer.no_speech)
- return tuple(sorted(set(suppress_tokens)))
- def _get_audio_features(self, mel: paddle.Tensor):
- if tuple(mel.shape[-2:]) == (
- self.model.dims.n_audio_ctx,
- self.model.dims.n_audio_state,
- ):
- # encoded audio features are given; skip audio encoding
- audio_features = mel
- else:
- audio_features = self.model.encoder(mel)
- return audio_features
- def _detect_language(
- self,
- audio_features: paddle.Tensor,
- tokens: paddle.Tensor,
- resource_path: str,
- ):
- languages = [self.options.language] * audio_features.shape[0]
- lang_probs = None
- if self.options.language is None or self.options.task == "lang_id":
- lang_tokens, lang_probs = self.model.detect_language(
- audio_features, self.tokenizer, self.resource_path
- )
- languages = [max(probs, key=probs.get) for probs in lang_probs]
- if self.options.language is None:
- tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
- return languages, lang_probs
- def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
- assert audio_features.shape[0] == tokens.shape[0]
- n_batch = tokens.shape[0]
- sum_logprobs: paddle.Tensor = paddle.zeros([n_batch], dtype=paddle.float32)
- no_speech_probs = [np.nan] * n_batch
- try:
- for i in range(self.sample_len):
- logits = self.inference.logits(tokens, audio_features)
- if (
- i == 0 and self.tokenizer.no_speech is not None
- ): # save no_speech_probs
- probs_at_sot = paddle.nn.functional.softmax(
- logits[:, self.sot_index], axis=-1, dtype=paddle.float32
- )
- no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
- # now we need to consider the logits at the last token only
- logits = logits[:, -1]
- # apply the logit filters, e.g. for suppressing or applying penalty to
- for logit_filter in self.logit_filters:
- logit_filter.apply(logits, tokens)
- # expand the tokens tensor with the selected next tokens
- tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
- if completed or tokens.shape[-1] > self.n_ctx:
- break
- finally:
- self.inference.cleanup_caching()
- return tokens, sum_logprobs, no_speech_probs
- @paddle.no_grad()
- def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
- self.decoder.reset()
- tokenizer: Tokenizer = self.tokenizer
- batch_size: int = mel.shape[0]
- audio_features: paddle.Tensor = self._get_audio_features(
- mel
- ) # encoder forward pass
- # replicate the initial token sequence once per audio item in the batch
- tokens: paddle.Tensor = paddle.to_tensor([self.initial_tokens]).tile(
- [batch_size, 1]
- )
- # detect language if requested, overwriting the language token
- languages, language_probs = self._detect_language(
- paddle.to_tensor(audio_features),
- paddle.to_tensor(tokens),
- self.resource_path,
- )
- if self.options.task == "lang_id":
- return [
- DecodingResult(
- audio_features=features, language=language, language_probs=probs
- )
- for features, language, probs in zip(
- audio_features, languages, language_probs
- )
- ]
- # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
- audio_features = paddle.repeat_interleave(
- audio_features, self.beam_size, axis=0
- )
- tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
- # call the main sampling loop
- tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
- # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
- audio_features = audio_features[:: self.beam_size]
- no_speech_probs = no_speech_probs[:: self.beam_size]
- assert audio_features.shape[0] == len(no_speech_probs) == batch_size
- tokens = tokens.reshape([batch_size, self.beam_size, -1])
- sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
- # get the final candidates for each group, and slice between the first sampled token and EOT
- tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
- tokens: List[List[paddle.Tensor]] = [
- [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
- for s in tokens
- ]
- # select the top-ranked sample in each group
- selected = self.sequence_ranker.rank(tokens, sum_logprobs)
- tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
- texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
- sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
- avg_logprobs: List[float] = [
- lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
- ]
- fields = (
- texts,
- languages,
- tokens,
- audio_features,
- avg_logprobs,
- no_speech_probs,
- )
- if len(set(map(len, fields))) != 1:
- raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
- return [
- DecodingResult(
- audio_features=features,
- language=language,
- tokens=tokens,
- text=text,
- avg_logprob=avg_logprob,
- no_speech_prob=no_speech_prob,
- temperature=self.options.temperature,
- compression_ratio=compression_ratio(text),
- )
- for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
- *fields
- )
- ]
@paddle.no_grad()
def decode(
    model: "Whisper",
    mel: paddle.Tensor,
    options: DecodingOptions = DecodingOptions(),
    resource_path: str = None,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    resource_path: str
        The path to the directory that holds the model assets

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options, resource_path).run(mel)

    if single:
        result = result[0]

    return result
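
# NOTE: a usage sketch for `decode`, assuming a loaded `model`, a 16 kHz wav
# file and a valid `resource_path` (all names are placeholders); N_FRAMES is
# assumed to be the module-level frame count for a 30-second segment:
#
#   mel = log_mel_spectrogram("audio.wav", resource_path=resource_path)
#   mel = pad_or_trim(mel, N_FRAMES, axis=-1)
#   result = decode(model, mel, DecodingOptions(language="en"), resource_path)
#   print(result.text)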
class Whisper(paddle.nn.Layer):
    """
    The `Whisper` module combines an AudioEncoder and a TextDecoder, and exposes
    `detect_language`, `transcribe`, and `decode` as methods.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: paddle.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
        return self.decoder.forward(tokens, audio_features)

    def forward(
        self, mel: paddle.Tensor, tokens: paddle.Tensor
    ) -> Dict[str, paddle.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return paddle.device.get_device()

    @property
    def is_multilingual(self):
        # multilingual checkpoints carry one extra language token (51865 vs. 51864)
        return self.dims.n_vocab == 51865
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[paddle.nn.Layer, paddle.Tensor]
            A dictionary object mapping each key/value projection module to its cache
        hooks : List[RemovableHandle]
            List of RemovableHandle objects that can be used to remove the installed hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if (
                module not in cache
                or output.shape[1] > self.decoder.positional_embedding.shape[0]
            ):
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = paddle.concat([cache[module], output], axis=1).detach()
            return cache[module]

        def install_hooks(layer: paddle.nn.Layer):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_post_hook(save_to_cache))
                hooks.append(layer.value.register_forward_post_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
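
    # NOTE: a sketch of the intended cache lifecycle, assuming the decoder's
    # forward accepts a `kv_cache` argument and the hooks follow paddle's
    # removable-handle convention (`hook.remove()`):
    #
    #   cache, hooks = model.install_kv_cache_hooks()
    #   try:
    #       logits = model.decoder(tokens, audio_features, kv_cache=cache)
    #   finally:
    #       for hook in hooks:
    #           hook.remove()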
    # bind the module-level functions as methods
    detect_language = detect_language
    set_inference_operations(get_inference_operations() + ["speech_transcribe"])
    transcribe = benchmark.timeit_with_options(name="speech_transcribe")(transcribe)
    decode = decode
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if paddle.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(axis=axis, index=paddle.arange(length))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # the transpose/pad/transpose round-trip assumes a 2-D
            # (channels, time) input such as a mel spectrogram
            array = paddle.transpose(array, (1, 0))
            array = paddle.nn.functional.pad(
                array,
                [pad for sizes in pad_widths[::-1] for pad in sizes],
                data_format="NLC",
            )
            array = paddle.transpose(array, (1, 0))
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # np.pad consumes per-axis pad widths directly, so no transposing
            # (which would also mis-order the widths) is needed here
            array = np.pad(array, pad_widths)

    return array
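
# NOTE: a quick shape check (illustrative values; the tensor padding branch
# assumes a 2-D input, while trimming works on the last axis directly):
#
#   mel = paddle.zeros([80, 1500])
#   pad_or_trim(mel, 3000).shape                       # -> [80, 3000]
#   pad_or_trim(paddle.zeros([80, 4000]), 3000).shape  # -> [80, 3000]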
def hann_window(n_fft: int = N_FFT):
    """
    Periodic Hann window.

    n_fft: The number of frequency components of the discrete Fourier transform.
    """
    return paddle.to_tensor(
        [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
        dtype=paddle.float32,
    )
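
# NOTE: a sanity check against NumPy (not used at runtime); the values match
# the periodic form 0.5 - 0.5 * cos(2 * pi * n / N), rather than np.hanning,
# which uses the symmetric N - 1 denominator:
#
#   w = hann_window(400)
#   np.allclose(w.numpy(), 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(400) / 400))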
@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
        return paddle.to_tensor(f[f"mel_{n_mels}"])
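
# NOTE: for n_fft=400 the loaded filterbank has shape [80, 201]
# (n_fft // 2 + 1 frequency bins), matching the magnitude matrix produced in
# log_mel_spectrogram below (resource_path is a placeholder):
#
#   filters = mel_filters(resource_path, n_mels=80)
#   filters.shape  # -> [80, 201]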
@function_requires_deps("soundfile")
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, paddle.Tensor],
    n_mels: int = N_MELS,
    resource_path: str = None,
):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
        The path to an audio file, or a NumPy array or Tensor containing the
        audio waveform sampled at 16 kHz

    n_mels: int
        The number of Mel-frequency filters; only 80 is supported

    Returns
    -------
    paddle.Tensor, shape = (80, n_frames)
        A Tensor that contains the log-Mel spectrogram
    """
    if not paddle.is_tensor(audio):
        if isinstance(audio, str):
            # read the first channel of the (possibly multi-channel) file
            audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
            audio = audio[:, 0]
        audio = paddle.to_tensor(audio)

    window = hann_window(N_FFT)
    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)

    # power spectrum of all but the last frame
    magnitudes = stft[:, :-1].abs() ** 2

    filters = mel_filters(resource_path, n_mels)
    mel_spec = filters @ magnitudes

    # log-compress, clamp the dynamic range to 8 log10-units below the peak,
    # then rescale
    log_spec = paddle.clip(mel_spec, min=1e-10).log10()
    log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
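
# NOTE: an end-to-end feature-extraction sketch tying the helpers together
# (the file name and resource_path are placeholders):
#
#   log_spec = log_mel_spectrogram("speech.wav", resource_path=resource_path)
#   log_spec.shape  # -> [80, n_frames]; after the clip/rescale steps the
#                   #    values span at most a 2.0-wide range ending at
#                   #    (log_spec.max() + 4.0) / 4.0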