processors.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
  15. import os
  16. import zlib
  17. from dataclasses import dataclass, field
  18. from functools import lru_cache
  19. from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
  20. import numpy as np
  21. import paddle
  22. from ....utils.deps import function_requires_deps, is_dep_available
  23. from ..common.tokenizer import GPTTokenizer
  24. if is_dep_available("soundfile"):
  25. import soundfile
  26. if is_dep_available("tqdm"):
  27. import tqdm
  28. __all__ = [
  29. "Whisper",
  30. "Tokenizer",
  31. ]
  32. def exact_div(x, y):
  33. assert x % y == 0
  34. return x // y
# Module-level audio/model constants.
# NOTE(review): units are not stated in this file; SAMPLE_RATE is presumably in Hz
# and CHUNK_LENGTH in seconds (their product is documented below as a sample count).
_MODELS = ["large"]
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
# exact_div asserts that N_SAMPLES is a whole multiple of HOP_LENGTH.
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
)  # 3000: number of frames in a mel spectrogram input
@dataclass
class ModelDimensions:
    """Integer size hyper-parameters of a model.

    NOTE(review): the code that consumes these fields is outside this view;
    the `n_audio_*` / `n_text_*` fields presumably configure the audio encoder
    and text decoder respectively -- confirm against the model constructor.
    """

    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
# Language code -> lowercase English language name.
# Insertion order matters: `build_tokenizer` creates one `<|code|>` special token per
# entry in this order, and `get_tokenizer` derives a language token id from the
# position of the code in this dict. Do not reorder or insert in the middle.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
# Language name -> language code: the inverse of LANGUAGES, extended with a few
# alternative names (aliases) that map onto codes already present in LANGUAGES.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
  173. def compression_ratio(text) -> float:
  174. return len(text) / len(zlib.compress(text.encode("utf-8")))
  175. def format_timestamp(
  176. seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
  177. ):
  178. assert seconds >= 0, "non-negative timestamp expected"
  179. milliseconds = round(seconds * 1000.0)
  180. hours = milliseconds // 3_600_000
  181. milliseconds -= hours * 3_600_000
  182. minutes = milliseconds // 60_000
  183. milliseconds -= minutes * 60_000
  184. seconds = milliseconds // 1_000
  185. milliseconds -= seconds * 1_000
  186. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  187. return (
  188. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  189. )
  190. @dataclass(frozen=True)
  191. class Tokenizer:
  192. """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
  193. tokenizer: "GPTTokenizer"
  194. language: Optional[str]
  195. sot_sequence: Tuple[int]
  196. def encode(self, text, **kwargs):
  197. return self.tokenizer.encode(text, **kwargs)
  198. def decode(
  199. self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
  200. ):
  201. if len(token_ids) > 1:
  202. ids_list = []
  203. for ids in token_ids:
  204. if paddle.is_tensor(ids):
  205. ids = ids.item()
  206. if ids < len(self.tokenizer):
  207. ids_list.append(ids)
  208. token_ids = ids_list
  209. elif len(token_ids) == 1:
  210. token_ids = token_ids[0]
  211. else:
  212. raise ValueError(f"token_ids {token_ids} load error.")
  213. return self.tokenizer.decode(token_ids, **kwargs)
  214. def decode_with_timestamps(self, tokens) -> str:
  215. """
  216. Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
  217. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
  218. """
  219. outputs = [[]]
  220. for token in tokens:
  221. if token >= self.timestamp_begin:
  222. timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
  223. outputs.append(timestamp)
  224. outputs.append([])
  225. else:
  226. outputs[-1].append(token)
  227. outputs = [
  228. s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
  229. ]
  230. return "".join(outputs)
  231. @property
  232. @lru_cache()
  233. def eot(self) -> int:
  234. return self.tokenizer.eos_token_id
  235. @property
  236. @lru_cache()
  237. def sot(self) -> int:
  238. return self._get_single_token_id("<|startoftranscript|>")
  239. @property
  240. @lru_cache()
  241. def sot_lm(self) -> int:
  242. return self._get_single_token_id("<|startoflm|>")
  243. @property
  244. @lru_cache()
  245. def sot_prev(self) -> int:
  246. return self._get_single_token_id("<|startofprev|>")
  247. @property
  248. @lru_cache()
  249. def no_speech(self) -> int:
  250. return self._get_single_token_id("<|nospeech|>")
  251. @property
  252. @lru_cache()
  253. def no_timestamps(self) -> int:
  254. return self._get_single_token_id("<|notimestamps|>")
  255. @property
  256. @lru_cache()
  257. def timestamp_begin(self) -> int:
  258. return self.tokenizer.all_special_ids[-1] + 1
  259. @property
  260. @lru_cache()
  261. def language_token(self) -> int:
  262. """Returns the token id corresponding to the value of the `language` field"""
  263. if self.language is None:
  264. raise ValueError("This tokenizer does not have language token configured")
  265. additional_tokens = dict(
  266. zip(
  267. self.tokenizer.additional_special_tokens,
  268. self.tokenizer.additional_special_tokens_ids,
  269. )
  270. )
  271. candidate = f"<|{self.language}|>"
  272. if candidate in additional_tokens:
  273. return additional_tokens[candidate]
  274. raise KeyError(f"Language {self.language} not found in tokenizer.")
  275. @property
  276. @lru_cache()
  277. def all_language_tokens(self) -> Tuple[int]:
  278. result = []
  279. for token, token_id in zip(
  280. self.tokenizer.additional_special_tokens,
  281. self.tokenizer.additional_special_tokens_ids,
  282. ):
  283. if token.strip("<|>") in LANGUAGES:
  284. result.append(token_id)
  285. return tuple(result)
  286. @property
  287. @lru_cache()
  288. def all_language_codes(self) -> Tuple[str]:
  289. return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
  290. @property
  291. @lru_cache()
  292. def sot_sequence_including_notimestamps(self) -> Tuple[int]:
  293. return tuple(list(self.sot_sequence) + [self.no_timestamps])
  294. @property
  295. @lru_cache()
  296. def non_speech_tokens(self) -> Tuple[int]:
  297. """
  298. Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
  299. annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
  300. - ♪♪♪
  301. - ( SPEAKING FOREIGN LANGUAGE )
  302. - [DAVID] Hey there,
  303. keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
  304. """
  305. symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
  306. symbols += (
  307. "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
  308. )
  309. # symbols that may be a single token or multiple tokens depending on the tokenizer.
  310. # In case they're multiple tokens, suppress the first token, which is safe because:
  311. # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
  312. # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
  313. miscellaneous = set("♩♪♫♬♭♮♯")
  314. assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
  315. # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
  316. result = {
  317. self.tokenizer.encode(" -").input_ids[0],
  318. self.tokenizer.encode(" '").input_ids[0],
  319. }
  320. for symbol in symbols + list(miscellaneous):
  321. for tokens in [
  322. self.tokenizer.encode(symbol).input_ids,
  323. self.tokenizer.encode(" " + symbol).input_ids,
  324. ]:
  325. if len(tokens) == 1 or symbol in miscellaneous:
  326. result.add(tokens[0])
  327. return tuple(sorted(result))
  328. def _get_single_token_id(self, text) -> int:
  329. tokens = self.tokenizer.encode(text).input_ids
  330. assert len(tokens) == 1, f"{text} is not encoded as a single token"
  331. return tokens[0]
  332. @lru_cache(maxsize=None)
  333. def build_tokenizer(resource_path: str, name: str = "gpt2"):
  334. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  335. path = os.path.join(resource_path, "assets", name)
  336. tokenizer = GPTTokenizer.from_pretrained(path)
  337. specials = [
  338. "<|startoftranscript|>",
  339. *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
  340. "<|translate|>",
  341. "<|transcribe|>",
  342. "<|startoflm|>",
  343. "<|startofprev|>",
  344. "<|nospeech|>",
  345. "<|notimestamps|>",
  346. ]
  347. tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
  348. return tokenizer
  349. @lru_cache(maxsize=None)
  350. def get_tokenizer(
  351. multilingual: bool,
  352. resource_path: str,
  353. *,
  354. task: Optional[str] = None, # Literal["transcribe", "translate", None]
  355. language: Optional[str] = None,
  356. ) -> Tokenizer:
  357. if language is not None:
  358. language = language.lower()
  359. if language not in LANGUAGES:
  360. if language in TO_LANGUAGE_CODE:
  361. language = TO_LANGUAGE_CODE[language]
  362. else:
  363. raise ValueError(f"Unsupported language: {language}")
  364. if multilingual:
  365. tokenizer_name = "multilingual"
  366. task = task or "transcribe"
  367. language = language or "en"
  368. else:
  369. tokenizer_name = "gpt2"
  370. task = None
  371. language = None
  372. tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
  373. all_special_ids: List[int] = tokenizer.all_special_ids
  374. sot: int = all_special_ids[1]
  375. translate: int = all_special_ids[-6]
  376. transcribe: int = all_special_ids[-5]
  377. langs = tuple(LANGUAGES.keys())
  378. sot_sequence = [sot]
  379. if language is not None:
  380. sot_sequence.append(sot + 1 + langs.index(language))
  381. if task is not None:
  382. sot_sequence.append(transcribe if task == "transcribe" else translate)
  383. return Tokenizer(
  384. tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
  385. )
  386. class MultiHeadAttention(paddle.nn.Layer):
  387. def __init__(self, n_state: int, n_head: int):
  388. super().__init__()
  389. self.n_head = n_head
  390. self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  391. self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
  392. self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  393. self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  394. def forward(
  395. self,
  396. x: paddle.Tensor,
  397. xa: Optional[paddle.Tensor] = None,
  398. mask: Optional[paddle.Tensor] = None,
  399. kv_cache: Optional[dict] = None,
  400. ):
  401. q = self.query(x)
  402. if kv_cache is None or xa is None or self.key not in kv_cache:
  403. # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
  404. # otherwise, perform key/value projections for self- or cross-attention as usual.
  405. k = self.key(x if xa is None else xa)
  406. v = self.value(x if xa is None else xa)
  407. else:
  408. # for cross-attention, calculate keys and values once and reuse in subsequent calls.
  409. k = kv_cache[self.key]
  410. v = kv_cache[self.value]
  411. wv = self.qkv_attention(q, k, v, mask)
  412. return self.out(wv)
  413. def qkv_attention(
  414. self,
  415. q: paddle.Tensor,
  416. k: paddle.Tensor,
  417. v: paddle.Tensor,
  418. mask: Optional[paddle.Tensor] = None,
  419. ):
  420. n_batch, n_ctx, n_state = q.shape
  421. scale = (n_state // self.n_head) ** -0.25
  422. q = (
  423. paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  424. * scale
  425. )
  426. k = (
  427. paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
  428. * scale
  429. )
  430. v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  431. qk = q @ k
  432. if mask is not None:
  433. qk = qk + mask[:n_ctx, :n_ctx]
  434. w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
  435. return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
  436. class ResidualAttentionBlock(paddle.nn.Layer):
  437. def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
  438. super().__init__()
  439. self.attn = MultiHeadAttention(n_state, n_head)
  440. self.attn_ln = paddle.nn.LayerNorm(n_state)
  441. self.cross_attn = (
  442. MultiHeadAttention(n_state, n_head) if cross_attention else None
  443. )
  444. self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None
  445. n_mlp = n_state * 4
  446. self.mlp = paddle.nn.Sequential(
  447. paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
  448. paddle.nn.GELU(),
  449. paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
  450. )
  451. self.mlp_ln = paddle.nn.LayerNorm(n_state)
  452. def forward(
  453. self,
  454. x: paddle.Tensor,
  455. xa: Optional[paddle.Tensor] = None,
  456. mask: Optional[paddle.Tensor] = None,
  457. kv_cache: Optional[dict] = None,
  458. ):
  459. x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
  460. if self.cross_attn:
  461. x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
  462. x = x + self.mlp(self.mlp_ln(x))
  463. return x
  464. def sinusoids(length, channels, max_timescale=10000):
  465. """Returns sinusoids for positional embedding"""
  466. assert channels % 2 == 0
  467. log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
  468. inv_timescales = paddle.exp(
  469. -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
  470. )
  471. scaled_time = (
  472. paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
  473. * inv_timescales[np.newaxis, :]
  474. )
  475. return paddle.to_tensor(
  476. paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
  477. )
  478. class AudioEncoder(paddle.nn.Layer):
  479. def __init__(
  480. self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  481. ):
  482. super().__init__()
  483. self.conv1 = paddle.nn.Conv1D(
  484. n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
  485. )
  486. self.conv2 = paddle.nn.Conv1D(
  487. n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
  488. )
  489. self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
  490. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  491. [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
  492. )
  493. self.ln_post = paddle.nn.LayerNorm(n_state)
  494. def forward(self, x: paddle.Tensor):
  495. """
  496. x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
  497. the mel spectrogram of the audio
  498. """
  499. x = paddle.nn.functional.gelu(self.conv1(x))
  500. x = paddle.nn.functional.gelu(self.conv2(x))
  501. x = paddle.transpose(x, (0, 2, 1))
  502. assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
  503. x = x + self.positional_embedding
  504. for block in self.blocks:
  505. x = block(x)
  506. x = self.ln_post(x)
  507. return x
  508. class TextDecoder(paddle.nn.Layer):
  509. def __init__(
  510. self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  511. ):
  512. super().__init__()
  513. self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
  514. self.positional_embedding = paddle.create_parameter(
  515. shape=[n_ctx, n_state], dtype="float32"
  516. )
  517. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  518. [
  519. ResidualAttentionBlock(n_state, n_head, cross_attention=True)
  520. for _ in range(n_layer)
  521. ]
  522. )
  523. self.ln = paddle.nn.LayerNorm(n_state)
  524. mask = paddle.full(shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32")
  525. mask = paddle.triu(mask, diagonal=1)
  526. self.register_buffer("mask", mask, persistable=False)
  527. def forward(
  528. self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
  529. ):
  530. """
  531. x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
  532. the text tokens
  533. xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
  534. the encoded audio features to be attended on
  535. """
  536. offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
  537. x = (
  538. self.token_embedding(x)
  539. + self.positional_embedding[offset : offset + x.shape[-1]]
  540. )
  541. x = x.to(xa.dtype)
  542. for block in self.blocks:
  543. x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
  544. x = self.ln(x)
  545. logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
  546. return logits
  547. @dataclass(frozen=True)
  548. class DecodingOptions:
  549. task: str = (
  550. "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
  551. )
  552. language: Optional[str] = (
  553. None # language that the audio is in; uses detected language if None
  554. )
  555. # sampling-related options
  556. temperature: float = 0.0
  557. sample_len: Optional[int] = None # maximum number of tokens to sample
  558. best_of: Optional[int] = (
  559. None # number of independent samples to collect, when t > 0
  560. )
  561. beam_size: Optional[int] = None # number of beams in beam search, when t == 0
  562. patience: Optional[float] = (
  563. None # patience in beam search (https://arxiv.org/abs/2204.05424)
  564. )
  565. # options for ranking generations (either beams or best-of-N samples)
  566. length_penalty: Optional[float] = (
  567. None # "alpha" in Google NMT, None defaults to length norm
  568. )
  569. # prompt, prefix, and token suppression
  570. prompt: Optional[Union[str, List[int]]] = (
  571. None # text or tokens for the previous context
  572. )
  573. prefix: Optional[Union[str, List[int]]] = (
  574. None # text or tokens to prefix the current context
  575. )
  576. suppress_blank: bool = True # this will suppress blank outputs
  577. # list of tokens ids (or comma-separated token ids) to suppress
  578. # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
  579. suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
  580. # timestamp sampling options
  581. without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
  582. max_initial_timestamp: Optional[float] = (
  583. 1.0 # the initial timestamp cannot be later than this
  584. )
  585. # implementation details
  586. fp16: bool = False # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record of the outcome of one decoding run.

    Only `audio_features` and `language` are required. The numeric statistics
    default to NaN, `tokens` to a fresh empty list and `text` to the empty string.
    NOTE(review): the code that fills these fields is outside this view; the
    field meanings are inferred from their names -- confirm against the decoder.
    """

    audio_features: paddle.Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    # default_factory gives every instance its own list
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
class Inference:
    """Interface for the decoder-side operations used while decoding.

    `logits` and `rearrange_kv_cache` raise NotImplementedError here and must be
    overridden; `cleanup_caching` does nothing by default.
    """

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
  609. class WhisperInference(Inference):
  610. def __init__(self, model: "Whisper", initial_token_length: int):
  611. self.model: "Whisper" = model
  612. self.initial_token_length = initial_token_length
  613. self.kv_cache = {}
  614. self.hooks = []
  615. def logits(
  616. self, tokens: paddle.Tensor, audio_features: paddle.Tensor
  617. ) -> paddle.Tensor:
  618. if not self.kv_cache:
  619. self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
  620. if tokens.shape[-1] > self.initial_token_length:
  621. # only need to use the last token except in the first forward pass
  622. tokens = tokens[:, -1:]
  623. return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  624. def cleanup_caching(self):
  625. for hook in self.hooks:
  626. hook.remove()
  627. self.kv_cache = {}
  628. self.hooks = []
  629. def rearrange_kv_cache(self, source_indices):
  630. for module, tensor in self.kv_cache.items():
  631. # update the key/value cache to contain the selected sequences
  632. self.kv_cache[module] = tensor[source_indices].detach()
  633. @paddle.no_grad()
  634. def detect_language(
  635. model: "Whisper",
  636. mel: paddle.Tensor,
  637. resource_path: str,
  638. tokenizer: Tokenizer = None,
  639. ) -> Tuple[paddle.Tensor, List[dict]]:
  640. """
  641. Detect the spoken language in the audio, and return them as list of strings, along with the ids
  642. of the most probable language tokens and the probability distribution over all language tokens.
  643. This is performed outside the main decode loop in order to not interfere with kv-caching.
  644. Returns
  645. -------
  646. language_tokens : Tensor, shape = (batch_size,)
  647. ids of the most probable language tokens, which appears after the startoftranscript token.
  648. language_probs : List[Dict[str, float]], length = batch_size
  649. list of dictionaries containing the probability distribution over all languages.
  650. """
  651. if tokenizer is None:
  652. tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
  653. if (
  654. tokenizer.language is None
  655. or tokenizer.language_token not in tokenizer.sot_sequence
  656. ):
  657. raise ValueError(
  658. "This model doesn't have language tokens so it can't perform lang id"
  659. )
  660. single = mel.ndim == 2
  661. if single:
  662. mel = mel.unsqueeze(0)
  663. # skip encoder forward pass if already-encoded audio features were given
  664. if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
  665. mel = model.encoder(mel)
  666. # forward pass using a single token, startoftranscript
  667. batch_size = mel.shape[0]
  668. x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
  669. logits = model.logits(x, mel)[:, 0]
  670. # collect detected languages; suppress all non-language tokens
  671. mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
  672. mask[list(tokenizer.all_language_tokens)] = False
  673. logits[:, mask] = -np.inf
  674. language_tokens = paddle.argmax(logits, axis=-1)
  675. language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
  676. language_probs = [
  677. {
  678. c: language_token_probs[i, j].tolist()
  679. for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
  680. }
  681. for i in range(batch_size)
  682. ]
  683. if single:
  684. language_tokens = language_tokens[0]
  685. language_probs = language_probs[0]
  686. return language_tokens, language_probs
  687. @function_requires_deps("tqdm")
  688. def transcribe(
  689. model: "Whisper",
  690. mel: paddle.Tensor,
  691. resource_path: str,
  692. *,
  693. verbose: Optional[bool] = None,
  694. temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
  695. compression_ratio_threshold: Optional[float] = 2.4,
  696. logprob_threshold: Optional[float] = -1.0,
  697. no_speech_threshold: Optional[float] = 0.6,
  698. condition_on_previous_text: bool = True,
  699. **decode_options,
  700. ):
  701. """
  702. Transcribe an audio file using Whisper
  703. Parameters
  704. ----------
  705. model: Whisper
  706. The Whisper model instance
  707. mel: paddle.Tensor
  708. The audio feature
  709. verbose: bool
  710. Whether to display the text being decoded to the console. If True, displays all the details,
  711. If False, displays minimal details. If None, does not display anything
  712. temperature: Union[float, Tuple[float, ...]]
  713. Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
  714. upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
  715. compression_ratio_threshold: float
  716. If the gzip compression ratio is above this value, treat as failed
  717. logprob_threshold: float
  718. If the average log probability over sampled tokens is below this value, treat as failed
  719. no_speech_threshold: float
  720. If the no_speech probability is higher than this value AND the average log probability
  721. over sampled tokens is below `logprob_threshold`, consider the segment as silent
  722. condition_on_previous_text: bool
  723. if True, the previous output of the model is provided as a prompt for the next window;
  724. disabling may make the text inconsistent across windows, but the model becomes less prone to
  725. getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
  726. decode_options: dict
  727. Keyword arguments to construct `DecodingOptions` instances
  728. Returns
  729. -------
  730. A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
  731. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
  732. """
  733. dtype = np.float32 # paddle only support float32
  734. if dtype == np.float32:
  735. decode_options["fp16"] = False
  736. if (
  737. decode_options.get("language") == "None"
  738. or decode_options.get("language", None) is None
  739. ):
  740. if not model.is_multilingual:
  741. decode_options["language"] = "en"
  742. else:
  743. if verbose:
  744. print(
  745. "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
  746. )
  747. segment = pad_or_trim(mel, N_FRAMES)
  748. _, probs = model.detect_language(segment, resource_path)
  749. decode_options["language"] = max(probs, key=probs.get)
  750. if verbose is not None:
  751. print(
  752. f"Detected language: {LANGUAGES[decode_options['language']].title()}"
  753. )
  754. language = decode_options["language"]
  755. task = decode_options.get("task", "transcribe")
  756. tokenizer = get_tokenizer(
  757. model.is_multilingual,
  758. resource_path=resource_path,
  759. language=language,
  760. task=task,
  761. )
  762. def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
  763. temperatures = (
  764. [temperature] if isinstance(temperature, (int, float)) else temperature
  765. )
  766. decode_result = None
  767. for t in temperatures:
  768. kwargs = {**decode_options}
  769. if t > 0:
  770. # disable beam_size and patience when t > 0
  771. kwargs.pop("beam_size", None)
  772. kwargs.pop("patience", None)
  773. else:
  774. # disable best_of when t == 0
  775. kwargs.pop("best_of", None)
  776. options = DecodingOptions(**kwargs, temperature=t)
  777. decode_result = model.decode(segment, options, resource_path)
  778. needs_fallback = False
  779. if (
  780. compression_ratio_threshold is not None
  781. and decode_result.compression_ratio > compression_ratio_threshold
  782. ):
  783. needs_fallback = True # too repetitive
  784. if (
  785. logprob_threshold is not None
  786. and decode_result.avg_logprob < logprob_threshold
  787. ):
  788. needs_fallback = True # average log probability is too low
  789. if not needs_fallback:
  790. break
  791. return decode_result
  792. seek = 0
  793. input_stride = exact_div(
  794. N_FRAMES, model.dims.n_audio_ctx
  795. ) # mel frames per output token: 2
  796. time_precision = (
  797. input_stride * HOP_LENGTH / SAMPLE_RATE
  798. ) # time per output token: 0.02 (seconds)
  799. all_tokens = []
  800. all_segments = []
  801. prompt_reset_since = 0
  802. initial_prompt = decode_options.pop("initial_prompt", None)
  803. if initial_prompt and initial_prompt != "None":
  804. initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
  805. all_tokens.extend(initial_prompt)
  806. else:
  807. initial_prompt = []
  808. def add_segment(
  809. *,
  810. start: float,
  811. end: float,
  812. text_tokens: paddle.Tensor,
  813. result: DecodingResult,
  814. ):
  815. text = tokenizer.decode(
  816. [token for token in text_tokens if token < tokenizer.eot]
  817. )
  818. if len(text.strip()) == 0: # skip empty text output
  819. return
  820. all_segments.append(
  821. {
  822. "id": len(all_segments),
  823. "seek": seek,
  824. "start": start,
  825. "end": end,
  826. "text": text,
  827. "tokens": result.tokens,
  828. "temperature": result.temperature,
  829. "avg_logprob": result.avg_logprob,
  830. "compression_ratio": result.compression_ratio,
  831. "no_speech_prob": result.no_speech_prob,
  832. }
  833. )
  834. if verbose:
  835. print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
  836. # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
  837. num_frames = mel.shape[-1]
  838. previous_seek_value = seek
  839. with tqdm.tqdm(
  840. total=num_frames, unit="frames", disable=verbose is not False
  841. ) as pbar:
  842. while seek < num_frames:
  843. timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
  844. segment = pad_or_trim(mel[:, seek:], N_FRAMES)
  845. segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
  846. decode_options["prompt"] = all_tokens[prompt_reset_since:]
  847. result: DecodingResult = decode_with_fallback(segment)
  848. tokens = paddle.to_tensor(result.tokens)
  849. if no_speech_threshold is not None:
  850. # no voice activity check
  851. should_skip = result.no_speech_prob > no_speech_threshold
  852. if (
  853. logprob_threshold is not None
  854. and result.avg_logprob > logprob_threshold
  855. ):
  856. # don't skip if the logprob is high enough, despite the no_speech_prob
  857. should_skip = False
  858. if should_skip:
  859. seek += segment.shape[
  860. -1
  861. ] # fast-forward to the next segment boundary
  862. continue
  863. timestamp_tokens: paddle.Tensor = tokens.greater_equal(
  864. paddle.to_tensor(tokenizer.timestamp_begin)
  865. )
  866. consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
  867. if (
  868. len(consecutive) > 0
  869. ): # if the output contains two consecutive timestamp tokens
  870. consecutive = paddle.add(consecutive, paddle.to_tensor(1))
  871. last_slice = 0
  872. for current_slice in consecutive:
  873. sliced_tokens = tokens[last_slice:current_slice]
  874. start_timestamp_position = (
  875. sliced_tokens[0].item() - tokenizer.timestamp_begin
  876. )
  877. end_timestamp_position = (
  878. sliced_tokens[-1].item() - tokenizer.timestamp_begin
  879. )
  880. add_segment(
  881. start=timestamp_offset
  882. + start_timestamp_position * time_precision,
  883. end=timestamp_offset + end_timestamp_position * time_precision,
  884. text_tokens=sliced_tokens[1:-1],
  885. result=result,
  886. )
  887. last_slice = current_slice
  888. last_timestamp_position = (
  889. tokens[last_slice - 1].item() - tokenizer.timestamp_begin
  890. )
  891. seek += last_timestamp_position * input_stride
  892. all_tokens.extend(tokens[: last_slice + 1].tolist())
  893. else:
  894. duration = segment_duration
  895. timestamps = tokens[timestamp_tokens.nonzero().flatten()]
  896. if (
  897. len(timestamps) > 0
  898. and timestamps[-1].item() != tokenizer.timestamp_begin
  899. ):
  900. # no consecutive timestamps but it has a timestamp; use the last one.
  901. # single timestamp at the end means no speech after the last timestamp.
  902. last_timestamp_position = (
  903. timestamps[-1].item() - tokenizer.timestamp_begin
  904. )
  905. duration = last_timestamp_position * time_precision
  906. add_segment(
  907. start=timestamp_offset,
  908. end=timestamp_offset + duration,
  909. text_tokens=tokens,
  910. result=result,
  911. )
  912. seek += segment.shape[-1]
  913. all_tokens.extend(tokens.tolist())
  914. if not condition_on_previous_text or result.temperature > 0.5:
  915. # do not feed the prompt tokens if a high temperature was used
  916. prompt_reset_since = len(all_tokens)
  917. # update progress bar
  918. pbar.update(min(num_frames, seek) - previous_seek_value)
  919. previous_seek_value = seek
  920. return dict(
  921. text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
  922. segments=all_segments,
  923. language=language,
  924. )
  925. class SequenceRanker:
  926. def rank(
  927. self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
  928. ) -> List[int]:
  929. """
  930. Given a list of groups of samples and their cumulative log probabilities,
  931. return the indices of the samples in each group to select as the final result
  932. """
  933. raise NotImplementedError
  934. class MaximumLikelihoodRanker(SequenceRanker):
  935. """
  936. Select the sample with the highest log probabilities, penalized using either
  937. a simple length normalization or Google NMT paper's length penalty
  938. """
  939. def __init__(self, length_penalty: Optional[float]):
  940. self.length_penalty = length_penalty
  941. def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
  942. def scores(logprobs, lengths):
  943. result = []
  944. for logprob, length in zip(logprobs, lengths):
  945. if self.length_penalty is None or self.length_penalty == "None":
  946. penalty = length
  947. else:
  948. # from the Google NMT paper
  949. penalty = ((5 + length) / 6) ** self.length_penalty
  950. result.append(logprob / penalty)
  951. return result
  952. # get the sequence with the highest score
  953. lengths = [[len(t) for t in s] for s in tokens]
  954. return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
  955. class TokenDecoder:
  956. def reset(self):
  957. """Initialize any stateful variables for decoding a new sequence"""
  958. def update(
  959. self,
  960. tokens: paddle.Tensor,
  961. logits: paddle.Tensor,
  962. sum_logprobs: paddle.Tensor,
  963. ) -> Tuple[paddle.Tensor, bool]:
  964. """Specify how to select the next token, based on the current trace and logits
  965. Parameters
  966. ----------
  967. tokens : Tensor, shape = (n_batch, current_sequence_length)
  968. all tokens in the context so far, including the prefix and sot_sequence tokens
  969. logits : Tensor, shape = (n_batch, vocab_size)
  970. per-token logits of the probability distribution at the current step
  971. sum_logprobs : Tensor, shape = (n_batch)
  972. cumulative log probabilities for each sequence
  973. Returns
  974. -------
  975. tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
  976. the tokens, appended with the selected next token
  977. completed : bool
  978. True if all sequences has reached the end of text
  979. """
  980. raise NotImplementedError
  981. def finalize(
  982. self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
  983. ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
  984. """Finalize search and return the final candidate sequences
  985. Parameters
  986. ----------
  987. tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
  988. all tokens in the context so far, including the prefix and sot_sequence
  989. sum_logprobs : Tensor, shape = (batch_size, beam_size)
  990. cumulative log probabilities for each sequence
  991. Returns
  992. -------
  993. tokens : Sequence[Sequence[Tensor]], length = batch_size
  994. sequence of Tensors containing candidate token sequences, for each audio input
  995. sum_logprobs : List[List[float]], length = batch_size
  996. sequence of cumulative log probabilities corresponding to the above
  997. """
  998. raise NotImplementedError
  999. class GreedyDecoder(TokenDecoder):
  1000. def __init__(self, temperature: float, eot: int):
  1001. self.temperature = temperature
  1002. self.eot = eot
  1003. def update(
  1004. self,
  1005. tokens: paddle.Tensor,
  1006. logits: paddle.Tensor,
  1007. sum_logprobs: paddle.Tensor,
  1008. ) -> Tuple[paddle.Tensor, bool]:
  1009. temperature = self.temperature
  1010. if temperature == 0:
  1011. next_tokens = paddle.argmax(logits, axis=-1)
  1012. else:
  1013. next_tokens = paddle.distribution.Categorical(
  1014. logits=logits / temperature
  1015. ).sample([1])
  1016. next_tokens = paddle.reshape(
  1017. next_tokens,
  1018. [
  1019. next_tokens.shape[0] * next_tokens.shape[1],
  1020. ],
  1021. )
  1022. logprobs = paddle.nn.functional.log_softmax(
  1023. logits, axis=-1, dtype=paddle.float32
  1024. )
  1025. current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
  1026. sum_logprobs += current_logprobs * paddle.to_tensor(
  1027. (tokens[:, -1] != self.eot), dtype=paddle.float32
  1028. )
  1029. next_tokens[tokens[:, -1] == self.eot] = self.eot
  1030. tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
  1031. completed = paddle.all((tokens[:, -1] == self.eot))
  1032. return tokens, completed
  1033. def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1034. # make sure each sequence has at least one EOT token at the end
  1035. tokens = paddle.nn.functional.pad(
  1036. tokens, (0, 1), value=self.eot, data_format="NCL"
  1037. )
  1038. return tokens, sum_logprobs.tolist()
  1039. class BeamSearchDecoder(TokenDecoder):
  1040. def __init__(
  1041. self,
  1042. beam_size: int,
  1043. eot: int,
  1044. inference: Inference,
  1045. patience: Optional[float] = None,
  1046. ):
  1047. self.beam_size = beam_size
  1048. self.eot = eot
  1049. self.inference = inference
  1050. self.patience = patience or 1.0
  1051. if patience is None or patience == "None":
  1052. self.patience = 1.0
  1053. else:
  1054. self.patience = patience
  1055. self.max_candidates: int = round(beam_size * self.patience)
  1056. self.finished_sequences = None
  1057. assert (
  1058. self.max_candidates > 0
  1059. ), f"Invalid beam size ({beam_size}) or patience ({patience})"
  1060. def reset(self):
  1061. self.finished_sequences = None
  1062. def update(
  1063. self,
  1064. tokens: paddle.Tensor,
  1065. logits: paddle.Tensor,
  1066. sum_logprobs: paddle.Tensor,
  1067. ) -> Tuple[paddle.Tensor, bool]:
  1068. if tokens.shape[0] % self.beam_size != 0:
  1069. raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
  1070. batch_size = tokens.shape[0] // self.beam_size
  1071. if self.finished_sequences is None: # for the first update
  1072. self.finished_sequences = [{} for _ in range(batch_size)]
  1073. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1074. next_tokens, source_indices, finished_sequences = [], [], []
  1075. for i in range(batch_size):
  1076. scores, sources, finished = {}, {}, {}
  1077. # STEP 1: calculate the cumulative log probabilities for possible candidates
  1078. for j in range(self.beam_size):
  1079. idx = i * self.beam_size + j
  1080. prefix = tokens[idx].tolist()
  1081. logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
  1082. for logprob, token in zip(logprob, token):
  1083. new_logprob = (sum_logprobs[idx] + logprob).item()
  1084. sequence = tuple(prefix + [token.item()])
  1085. scores[sequence] = new_logprob
  1086. sources[sequence] = idx
  1087. # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
  1088. saved = 0
  1089. for sequence in sorted(scores, key=scores.get, reverse=True):
  1090. if sequence[-1] == self.eot:
  1091. finished[sequence] = scores[sequence]
  1092. else:
  1093. sum_logprobs[len(next_tokens)] = scores[sequence]
  1094. next_tokens.append(sequence)
  1095. source_indices.append(sources[sequence])
  1096. saved += 1
  1097. if saved == self.beam_size:
  1098. break
  1099. finished_sequences.append(finished)
  1100. tokens = paddle.to_tensor(next_tokens)
  1101. self.inference.rearrange_kv_cache(source_indices)
  1102. # add newly finished sequences to self.finished_sequences
  1103. assert len(self.finished_sequences) == len(finished_sequences)
  1104. for previously_finished, newly_finished in zip(
  1105. self.finished_sequences, finished_sequences
  1106. ):
  1107. for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
  1108. if len(previously_finished) >= self.max_candidates:
  1109. break # the candidate list is full
  1110. previously_finished[seq] = newly_finished[seq]
  1111. # mark as completed if all audio has enough number of samples
  1112. completed = all(
  1113. len(sequences) >= self.max_candidates
  1114. for sequences in self.finished_sequences
  1115. )
  1116. return tokens, completed
  1117. def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1118. # collect all finished sequences, including patience, and add unfinished ones if not enough
  1119. sum_logprobs = sum_logprobs.cpu()
  1120. for i, sequences in enumerate(self.finished_sequences):
  1121. if (
  1122. len(sequences) < self.beam_size
  1123. ): # when not enough sequences are finished
  1124. for j in list(np.argsort(sum_logprobs[i]))[::-1]:
  1125. sequence = preceding_tokens[i, j].tolist() + [self.eot]
  1126. sequences[tuple(sequence)] = sum_logprobs[i][j].item()
  1127. if len(sequences) >= self.beam_size:
  1128. break
  1129. tokens: List[List[paddle.Tensor]] = [
  1130. [paddle.to_tensor(seq) for seq in sequences.keys()]
  1131. for sequences in self.finished_sequences
  1132. ]
  1133. sum_logprobs: List[List[float]] = [
  1134. list(sequences.values()) for sequences in self.finished_sequences
  1135. ]
  1136. return tokens, sum_logprobs
  1137. class LogitFilter:
  1138. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
  1139. """Apply any filtering or masking to logits in-place
  1140. Parameters
  1141. ----------
  1142. logits : Tensor, shape = (n_batch, vocab_size)
  1143. per-token logits of the probability distribution at the current step
  1144. tokens : Tensor, shape = (n_batch, current_sequence_length)
  1145. all tokens in the context so far, including the prefix and sot_sequence tokens
  1146. """
  1147. raise NotImplementedError
  1148. class SuppressBlank(LogitFilter):
  1149. def __init__(self, tokenizer: Tokenizer, sample_begin: int):
  1150. self.tokenizer = tokenizer
  1151. self.sample_begin = sample_begin
  1152. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1153. if tokens.shape[1] == self.sample_begin:
  1154. logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
  1155. -np.inf
  1156. )
  1157. class SuppressTokens(LogitFilter):
  1158. def __init__(self, suppress_tokens: Sequence[int]):
  1159. self.suppress_tokens = list(suppress_tokens)
  1160. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1161. logits[:, self.suppress_tokens] = -np.inf
  1162. class ApplyTimestampRules(LogitFilter):
  1163. def __init__(
  1164. self,
  1165. tokenizer: Tokenizer,
  1166. sample_begin: int,
  1167. max_initial_timestamp_index: Optional[int],
  1168. ):
  1169. self.tokenizer = tokenizer
  1170. self.sample_begin = sample_begin
  1171. self.max_initial_timestamp_index = max_initial_timestamp_index
  1172. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1173. # suppress <|notimestamps|> which is handled by without_timestamps
  1174. if self.tokenizer.no_timestamps is not None:
  1175. logits[:, self.tokenizer.no_timestamps] = -np.inf
  1176. # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
  1177. for k in range(tokens.shape[0]):
  1178. seq = [t for t in tokens[k, self.sample_begin :].tolist()]
  1179. last_was_timestamp = (
  1180. len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
  1181. )
  1182. penultimate_was_timestamp = (
  1183. len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
  1184. )
  1185. if last_was_timestamp:
  1186. if penultimate_was_timestamp: # has to be non-timestamp
  1187. logits[k, self.tokenizer.timestamp_begin :] = -np.inf
  1188. else: # cannot be normal text tokens
  1189. logits[k, : self.tokenizer.eot] = -np.inf
  1190. # apply the `max_initial_timestamp` option
  1191. if (
  1192. tokens.shape[1] == self.sample_begin
  1193. and self.max_initial_timestamp_index is not None
  1194. ):
  1195. last_allowed = (
  1196. self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
  1197. )
  1198. logits[:, last_allowed + 1 :] = -np.inf
  1199. # if sum of probability over timestamps is above any other token, sample timestamp
  1200. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1201. for k in range(tokens.shape[0]):
  1202. # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
  1203. # To bypass this issue in CI, we have decomposed the operation into separate steps.
  1204. # It will raise 2e-6 difference in precision.
  1205. # TODO: revert this after logsumexp been fixed.
  1206. timestamp_logprob = paddle.exp(
  1207. logprobs[k, self.tokenizer.timestamp_begin :]
  1208. )
  1209. timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
  1210. timestamp_logprob = paddle.log(timestamp_logprob)
  1211. max_text_token_logprob = paddle.max(
  1212. logprobs[k, : self.tokenizer.timestamp_begin]
  1213. )
  1214. if timestamp_logprob > max_text_token_logprob:
  1215. logits[k, : self.tokenizer.timestamp_begin] = -np.inf
  1216. class DecodingTask:
  1217. inference: Inference
  1218. sequence_ranker: SequenceRanker
  1219. decoder: TokenDecoder
  1220. logit_filters: List[LogitFilter]
  1221. def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
  1222. self.model = model
  1223. language = options.language or "en"
  1224. tokenizer = get_tokenizer(
  1225. model.is_multilingual,
  1226. resource_path=resource_path,
  1227. language=language,
  1228. task=options.task,
  1229. )
  1230. self.tokenizer: Tokenizer = tokenizer
  1231. self.options: DecodingOptions = self._verify_options(options)
  1232. self.resource_path: str = resource_path
  1233. self.beam_size: int = options.beam_size or options.best_of or 1
  1234. self.n_ctx: int = model.dims.n_text_ctx
  1235. self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
  1236. self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
  1237. if self.options.without_timestamps:
  1238. self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
  1239. self.initial_tokens: Tuple[int] = self._get_initial_tokens()
  1240. self.sample_begin: int = len(self.initial_tokens)
  1241. self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
  1242. # inference: implements the forward pass through the decoder, including kv caching
  1243. self.inference = WhisperInference(model, len(self.initial_tokens))
  1244. # sequence ranker: implements how to rank a group of sampled sequences
  1245. self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
  1246. # decoder: implements how to select the next tokens, given the autoregressive distribution
  1247. if options.beam_size is not None:
  1248. self.decoder = BeamSearchDecoder(
  1249. options.beam_size, tokenizer.eot, self.inference, options.patience
  1250. )
  1251. else:
  1252. self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
  1253. # logit filters: applies various rules to suppress or penalize certain tokens
  1254. self.logit_filters = []
  1255. if self.options.suppress_blank:
  1256. self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
  1257. if self.options.suppress_tokens:
  1258. self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
  1259. if not options.without_timestamps:
  1260. precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
  1261. max_initial_timestamp_index = None
  1262. if options.max_initial_timestamp:
  1263. max_initial_timestamp_index = round(
  1264. self.options.max_initial_timestamp / precision
  1265. )
  1266. self.logit_filters.append(
  1267. ApplyTimestampRules(
  1268. tokenizer, self.sample_begin, max_initial_timestamp_index
  1269. )
  1270. )
  1271. def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
  1272. if options.beam_size is not None and options.best_of is not None:
  1273. raise ValueError("beam_size and best_of can't be given together")
  1274. if options.temperature == 0:
  1275. if options.best_of is not None:
  1276. raise ValueError("best_of with greedy sampling (T=0) is not compatible")
  1277. if options.patience is not None and options.beam_size is None:
  1278. raise ValueError("patience requires beam_size to be given")
  1279. if options.length_penalty is not None and options.length_penalty != "None":
  1280. if not (0 <= options.length_penalty <= 1):
  1281. raise ValueError(
  1282. "length_penalty (alpha) should be a value between 0 and 1"
  1283. )
  1284. return options
  1285. def _get_initial_tokens(self) -> Tuple[int]:
  1286. tokens = list(self.sot_sequence)
  1287. prefix = self.options.prefix
  1288. prompt = self.options.prompt
  1289. if prefix:
  1290. prefix_tokens = (
  1291. self.tokenizer.encode(" " + prefix.strip().input_ids)
  1292. if isinstance(prefix, str)
  1293. else prefix
  1294. )
  1295. if self.sample_len is not None:
  1296. max_prefix_len = self.n_ctx // 2 - self.sample_len
  1297. prefix_tokens = prefix_tokens[-max_prefix_len:]
  1298. tokens = tokens + prefix_tokens
  1299. if prompt:
  1300. prompt_tokens = (
  1301. self.tokenizer.encode(" " + prompt.strip().input_ids)
  1302. if isinstance(prompt, str)
  1303. else prompt
  1304. )
  1305. tokens = (
  1306. [self.tokenizer.sot_prev]
  1307. + prompt_tokens[-(self.n_ctx // 2 - 1) :]
  1308. + tokens
  1309. )
  1310. return tuple(tokens)
  1311. def _get_suppress_tokens(self) -> Tuple[int]:
  1312. suppress_tokens = self.options.suppress_tokens
  1313. if isinstance(suppress_tokens, str):
  1314. suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
  1315. if -1 in suppress_tokens:
  1316. suppress_tokens = [t for t in suppress_tokens if t >= 0]
  1317. suppress_tokens.extend(self.tokenizer.non_speech_tokens)
  1318. elif suppress_tokens is None or len(suppress_tokens) == 0:
  1319. suppress_tokens = [] # interpret empty string as an empty list
  1320. else:
  1321. assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
  1322. suppress_tokens.extend(
  1323. [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
  1324. )
  1325. if self.tokenizer.no_speech is not None:
  1326. # no-speech probability is collected separately
  1327. suppress_tokens.append(self.tokenizer.no_speech)
  1328. return tuple(sorted(set(suppress_tokens)))
  1329. def _get_audio_features(self, mel: paddle.Tensor):
  1330. if mel.shape[-2:] == (
  1331. self.model.dims.n_audio_ctx,
  1332. self.model.dims.n_audio_state,
  1333. ):
  1334. # encoded audio features are given; skip audio encoding
  1335. audio_features = mel
  1336. else:
  1337. audio_features = self.model.encoder(mel)
  1338. return audio_features
  1339. def _detect_language(
  1340. self,
  1341. audio_features: paddle.Tensor,
  1342. tokens: paddle.Tensor,
  1343. resource_path: str,
  1344. ):
  1345. languages = [self.options.language] * audio_features.shape[0]
  1346. lang_probs = None
  1347. if self.options.language is None or self.options.task == "lang_id":
  1348. lang_tokens, lang_probs = self.model.detect_language(
  1349. audio_features, self.tokenizer, self.resource_path
  1350. )
  1351. languages = [max(probs, key=probs.get) for probs in lang_probs]
  1352. if self.options.language is None:
  1353. tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
  1354. return languages, lang_probs
  1355. def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
  1356. assert audio_features.shape[0] == tokens.shape[0]
  1357. n_batch = tokens.shape[0]
  1358. sum_logprobs: paddle.Tensor = paddle.zeros(
  1359. paddle.to_tensor(n_batch), dtype=paddle.float32
  1360. )
  1361. no_speech_probs = [np.nan] * n_batch
  1362. try:
  1363. for i in range(self.sample_len):
  1364. logits = self.inference.logits(tokens, audio_features)
  1365. if (
  1366. i == 0 and self.tokenizer.no_speech is not None
  1367. ): # save no_speech_probs
  1368. probs_at_sot = paddle.nn.functional.softmax(
  1369. logits[:, self.sot_index], axis=-1, dtype=paddle.float32
  1370. )
  1371. no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
  1372. # now we need to consider the logits at the last token only
  1373. logits = logits[:, -1]
  1374. # apply the logit filters, e.g. for suppressing or applying penalty to
  1375. for logit_filter in self.logit_filters:
  1376. logit_filter.apply(logits, tokens)
  1377. # expand the tokens tensor with the selected next tokens
  1378. tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
  1379. if completed or tokens.shape[-1] > self.n_ctx:
  1380. break
  1381. finally:
  1382. self.inference.cleanup_caching()
  1383. return tokens, sum_logprobs, no_speech_probs
  1384. @paddle.no_grad()
  1385. def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
  1386. self.decoder.reset()
  1387. tokenizer: Tokenizer = self.tokenizer
  1388. batch_size: int = mel.shape[0]
  1389. audio_features: paddle.Tensor = self._get_audio_features(
  1390. mel
  1391. ) # encoder forward pass
  1392. tokens: paddle.Tensor
  1393. if batch_size > 1:
  1394. for i in range(batch_size):
  1395. tokens = paddle.concat(
  1396. x=[
  1397. paddle.to_tensor([self.initial_tokens]),
  1398. paddle.to_tensor([self.initial_tokens]),
  1399. ],
  1400. axis=0,
  1401. )
  1402. elif batch_size == 1:
  1403. tokens = paddle.to_tensor([self.initial_tokens])
  1404. # detect language if requested, overwriting the language token
  1405. languages, language_probs = self._detect_language(
  1406. paddle.to_tensor(audio_features),
  1407. paddle.to_tensor(tokens),
  1408. self.resource_path,
  1409. )
  1410. if self.options.task == "lang_id":
  1411. return [
  1412. DecodingResult(
  1413. audio_features=features, language=language, language_probs=probs
  1414. )
  1415. for features, language, probs in zip(
  1416. audio_features, languages, language_probs
  1417. )
  1418. ]
  1419. # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
  1420. audio_features = paddle.repeat_interleave(
  1421. audio_features, self.beam_size, axis=0
  1422. )
  1423. tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
  1424. # call the main sampling loop
  1425. tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  1426. # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
  1427. audio_features = audio_features[:: self.beam_size]
  1428. no_speech_probs = no_speech_probs[:: self.beam_size]
  1429. assert audio_features.shape[0] == len(no_speech_probs) == batch_size
  1430. tokens = tokens.reshape([batch_size, self.beam_size, -1])
  1431. sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
  1432. # get the final candidates for each group, and slice between the first sampled token and EOT
  1433. tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
  1434. tokens: List[List[paddle.Tensor]] = [
  1435. [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
  1436. for s in tokens
  1437. ]
  1438. # select the top-ranked sample in each group
  1439. selected = self.sequence_ranker.rank(tokens, sum_logprobs)
  1440. tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
  1441. texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
  1442. sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
  1443. avg_logprobs: List[float] = [
  1444. lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
  1445. ]
  1446. fields = (
  1447. texts,
  1448. languages,
  1449. tokens,
  1450. audio_features,
  1451. avg_logprobs,
  1452. no_speech_probs,
  1453. )
  1454. if len(set(map(len, fields))) != 1:
  1455. raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
  1456. return [
  1457. DecodingResult(
  1458. audio_features=features,
  1459. language=language,
  1460. tokens=tokens,
  1461. text=text,
  1462. avg_logprob=avg_logprob,
  1463. no_speech_prob=no_speech_prob,
  1464. temperature=self.options.temperature,
  1465. compression_ratio=compression_ratio(text),
  1466. )
  1467. for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
  1468. *fields
  1469. )
  1470. ]
  1471. @paddle.no_grad()
  1472. def decode(
  1473. model: "Whisper",
  1474. mel: paddle.Tensor,
  1475. options: DecodingOptions = DecodingOptions(),
  1476. resource_path=str,
  1477. ) -> Union[DecodingResult, List[DecodingResult]]:
  1478. """
  1479. Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
  1480. Parameters
  1481. ----------
  1482. model: Whisper
  1483. the Whisper model instance
  1484. mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
  1485. A tensor containing the Mel spectrogram(s)
  1486. options: DecodingOptions
  1487. A dataclass that contains all necessary options for decoding 30-second segments
  1488. Returns
  1489. -------
  1490. result: Union[DecodingResult, List[DecodingResult]]
  1491. The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
  1492. """
  1493. single = mel.ndim == 2
  1494. if single:
  1495. mel = mel.unsqueeze(0)
  1496. result = DecodingTask(model, options, resource_path).run(mel)
  1497. if single:
  1498. result = result[0]
  1499. return result
  1500. class Whisper(paddle.nn.Layer):
  1501. """
  1502. The `Whisper` module use AudioEncoder and TextDecoder, and return detect_language, transcribe, decode.
  1503. """
  1504. def __init__(self, dims: ModelDimensions):
  1505. super().__init__()
  1506. self.dims = dims
  1507. self.encoder = AudioEncoder(
  1508. self.dims.n_mels,
  1509. self.dims.n_audio_ctx,
  1510. self.dims.n_audio_state,
  1511. self.dims.n_audio_head,
  1512. self.dims.n_audio_layer,
  1513. )
  1514. self.decoder = TextDecoder(
  1515. self.dims.n_vocab,
  1516. self.dims.n_text_ctx,
  1517. self.dims.n_text_state,
  1518. self.dims.n_text_head,
  1519. self.dims.n_text_layer,
  1520. )
  1521. def embed_audio(self, mel: paddle.Tensor):
  1522. return self.encoder.forward(mel)
  1523. def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
  1524. return self.decoder.forward(tokens, audio_features)
  1525. def forward(
  1526. self, mel: paddle.Tensor, tokens: paddle.Tensor
  1527. ) -> Dict[str, paddle.Tensor]:
  1528. return self.decoder(tokens, self.encoder(mel))
  1529. @property
  1530. def device(self):
  1531. return paddle.device.get_device()
  1532. @property
  1533. def is_multilingual(self):
  1534. return self.dims.n_vocab == 51865
  1535. def install_kv_cache_hooks(self, cache: Optional[dict] = None):
  1536. """
  1537. The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
  1538. tensors calculated for the previous positions. This method returns a dictionary that stores
  1539. all caches, and the necessary hooks for the key and value projection modules that save the
  1540. intermediate tensors to be reused during later calculations.
  1541. Returns
  1542. -------
  1543. cache : Dict[nn.Layer, paddle.Tensor]
  1544. A dictionary object mapping the key/value projection modules to its cache
  1545. hooks : List[RemovableHandle]
  1546. List of PyTorch RemovableHandle objects to stop the hooks to be called
  1547. """
  1548. cache = {**cache} if cache is not None else {}
  1549. hooks = []
  1550. def save_to_cache(module, _, output):
  1551. if (
  1552. module not in cache
  1553. or output.shape[1] > self.decoder.positional_embedding.shape[0]
  1554. ):
  1555. cache[module] = (
  1556. output # save as-is, for the first token or cross attention
  1557. )
  1558. else:
  1559. cache[module] = paddle.concat([cache[module], output], axis=1).detach()
  1560. return cache[module]
  1561. def install_hooks(layer: paddle.nn.Layer):
  1562. if isinstance(layer, MultiHeadAttention):
  1563. hooks.append(layer.key.register_forward_post_hook(save_to_cache))
  1564. hooks.append(layer.value.register_forward_post_hook(save_to_cache))
  1565. self.decoder.apply(install_hooks)
  1566. return cache, hooks
  1567. detect_language = detect_language
  1568. transcribe = transcribe
  1569. decode = decode
  1570. def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
  1571. """
  1572. Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
  1573. """
  1574. if paddle.is_tensor(array):
  1575. if array.shape[axis] > length:
  1576. array = array.index_select(axis=axis, index=paddle.arange(length))
  1577. if array.shape[axis] < length:
  1578. pad_widths = [(0, 0)] * array.ndim
  1579. pad_widths[axis] = (0, length - array.shape[axis])
  1580. array = paddle.transpose(array, (1, 0))
  1581. array = paddle.nn.functional.pad(
  1582. array,
  1583. [pad for sizes in pad_widths[::-1] for pad in sizes],
  1584. data_format="NLC",
  1585. )
  1586. array = paddle.transpose(array, (1, 0))
  1587. else:
  1588. if array.shape[axis] > length:
  1589. array = array.take(indices=range(length), axis=axis)
  1590. if array.shape[axis] < length:
  1591. pad_widths = [(0, 0)] * array.ndim
  1592. pad_widths[axis] = (0, length - array.shape[axis])
  1593. array = paddle.transpose(array, (1, 0))
  1594. array = np.pad(array, pad_widths)
  1595. array = paddle.transpose(array, (1, 0))
  1596. return array
  1597. def hann_window(n_fft: int = N_FFT):
  1598. """
  1599. hanning window
  1600. n_fft: The number of frequency components of the discrete Fourier transform.
  1601. """
  1602. return paddle.to_tensor(
  1603. [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
  1604. dtype=paddle.float32,
  1605. )
  1606. @lru_cache(maxsize=None)
  1607. def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
  1608. """
  1609. load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
  1610. Allows decoupling librosa dependency; saved using:
  1611. np.savez_compressed(
  1612. "mel_filters.npz",
  1613. mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
  1614. )
  1615. """
  1616. assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
  1617. with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
  1618. return paddle.to_tensor(f[f"mel_{n_mels}"])
  1619. @function_requires_deps("soundfile")
  1620. def log_mel_spectrogram(
  1621. audio: Union[str, np.ndarray, paddle.Tensor],
  1622. n_mels: int = N_MELS,
  1623. resource_path: str = None,
  1624. ):
  1625. """
  1626. Compute the log-Mel spectrogram of
  1627. Parameters
  1628. ----------
  1629. audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
  1630. The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
  1631. n_mels: int
  1632. The number of Mel-frequency filters, only 80 is supported
  1633. Returns
  1634. -------
  1635. paddle.Tensor, shape = (80, n_frames)
  1636. A Tensor that contains the Mel spectrogram
  1637. """
  1638. if not paddle.is_tensor(audio):
  1639. if isinstance(audio, str):
  1640. audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
  1641. audio = audio[:, 0]
  1642. audio = paddle.to_tensor(audio)
  1643. window = hann_window(N_FFT)
  1644. stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
  1645. magnitudes = stft[:, :-1].abs() ** 2
  1646. filters = mel_filters(resource_path, n_mels)
  1647. mel_spec = filters @ magnitudes
  1648. mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
  1649. log_spec = paddle.clip(mel_spec, min=1e-10).log10()
  1650. log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
  1651. log_spec = (log_spec + 4.0) / 4.0
  1652. return log_spec