processors.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
  15. import os
  16. import zlib
  17. from dataclasses import dataclass, field
  18. from functools import lru_cache
  19. from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
  20. import numpy as np
  21. import paddle
  22. from ....utils.deps import function_requires_deps, is_dep_available
  23. from ...utils.benchmark import (
  24. benchmark,
  25. get_inference_operations,
  26. set_inference_operations,
  27. )
  28. from ..common.tokenizer import GPTTokenizer
  29. if is_dep_available("soundfile"):
  30. import soundfile
  31. if is_dep_available("tqdm"):
  32. import tqdm
  33. __all__ = [
  34. "Whisper",
  35. "Tokenizer",
  36. ]
  37. def exact_div(x, y):
  38. assert x % y == 0
  39. return x // y
# Model size names; not referenced anywhere else in the visible part of this file.
_MODELS = ["large"]
# Audio front-end constants. The unit notes below are assumptions drawn from how
# the values are combined here (N_SAMPLES, N_FRAMES) -- TODO confirm against the
# feature-extraction code, which is not part of this section.
SAMPLE_RATE = 16000  # presumably samples per second
N_FFT = 400  # presumably the FFT window size in samples
N_MELS = 80  # presumably the number of mel bins
HOP_LENGTH = 160  # samples per spectrogram frame, as implied by N_FRAMES below
CHUNK_LENGTH = 30  # presumably seconds per chunk, as implied by N_SAMPLES below
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
# exact_div asserts that HOP_LENGTH divides N_SAMPLES evenly.
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input
@dataclass
class ModelDimensions:
    """Size hyper-parameters describing a Whisper model.

    In the visible part of this file only ``n_audio_ctx`` and ``n_audio_state``
    are read: ``detect_language`` compares them with the trailing two dimensions
    of its input to decide whether the audio has already been encoded. The
    per-field notes below are inferred from the field names and should be
    confirmed against the model constructor.
    """

    n_mels: int  # presumably mel bins of the input spectrogram -- TODO confirm
    n_audio_ctx: int  # length of the encoded audio sequence (see detect_language)
    n_audio_state: int  # feature width of the encoded audio (see detect_language)
    n_audio_head: int  # presumably attention heads in the audio encoder -- TODO confirm
    n_audio_layer: int  # presumably number of audio encoder blocks -- TODO confirm
    n_vocab: int  # presumably the text vocabulary size -- TODO confirm
    n_text_ctx: int  # presumably the maximum text sequence length -- TODO confirm
    n_text_state: int  # presumably the text decoder feature width -- TODO confirm
    n_text_head: int  # presumably attention heads in the text decoder -- TODO confirm
    n_text_layer: int  # presumably number of text decoder blocks -- TODO confirm
# Mapping from language code to lower-case English language name.
# NOTE: the entry order is significant. `build_tokenizer` registers one
# "<|code|>" special token per entry in this order, and `get_tokenizer` derives
# a language token id as `sot + 1 + <position of the code in this dict>`.
# Append new entries only with care; do not reorder or insert in the middle.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
  163. # language code lookup by name, with a few language aliases
  164. TO_LANGUAGE_CODE = {
  165. **{language: code for code, language in LANGUAGES.items()},
  166. "burmese": "my",
  167. "valencian": "ca",
  168. "flemish": "nl",
  169. "haitian": "ht",
  170. "letzeburgesch": "lb",
  171. "pushto": "ps",
  172. "panjabi": "pa",
  173. "moldavian": "ro",
  174. "moldovan": "ro",
  175. "sinhalese": "si",
  176. "castilian": "es",
  177. }
  178. def compression_ratio(text) -> float:
  179. return len(text) / len(zlib.compress(text.encode("utf-8")))
  180. def format_timestamp(
  181. seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
  182. ):
  183. assert seconds >= 0, "non-negative timestamp expected"
  184. milliseconds = round(seconds * 1000.0)
  185. hours = milliseconds // 3_600_000
  186. milliseconds -= hours * 3_600_000
  187. minutes = milliseconds // 60_000
  188. milliseconds -= minutes * 60_000
  189. seconds = milliseconds // 1_000
  190. milliseconds -= seconds * 1_000
  191. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  192. return (
  193. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  194. )
  195. @dataclass(frozen=True)
  196. class Tokenizer:
  197. """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
  198. tokenizer: "GPTTokenizer"
  199. language: Optional[str]
  200. sot_sequence: Tuple[int]
  201. def encode(self, text, **kwargs):
  202. return self.tokenizer.encode(text, **kwargs)
  203. def decode(
  204. self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
  205. ):
  206. if len(token_ids) > 1:
  207. ids_list = []
  208. for ids in token_ids:
  209. if paddle.is_tensor(ids):
  210. ids = ids.item()
  211. if ids < len(self.tokenizer):
  212. ids_list.append(ids)
  213. token_ids = ids_list
  214. elif len(token_ids) == 1:
  215. token_ids = token_ids[0]
  216. else:
  217. raise ValueError(f"token_ids {token_ids} load error.")
  218. return self.tokenizer.decode(token_ids, **kwargs)
  219. def decode_with_timestamps(self, tokens) -> str:
  220. """
  221. Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
  222. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
  223. """
  224. outputs = [[]]
  225. for token in tokens:
  226. if token >= self.timestamp_begin:
  227. timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
  228. outputs.append(timestamp)
  229. outputs.append([])
  230. else:
  231. outputs[-1].append(token)
  232. outputs = [
  233. s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
  234. ]
  235. return "".join(outputs)
  236. @property
  237. @lru_cache()
  238. def eot(self) -> int:
  239. return self.tokenizer.eos_token_id
  240. @property
  241. @lru_cache()
  242. def sot(self) -> int:
  243. return self._get_single_token_id("<|startoftranscript|>")
  244. @property
  245. @lru_cache()
  246. def sot_lm(self) -> int:
  247. return self._get_single_token_id("<|startoflm|>")
  248. @property
  249. @lru_cache()
  250. def sot_prev(self) -> int:
  251. return self._get_single_token_id("<|startofprev|>")
  252. @property
  253. @lru_cache()
  254. def no_speech(self) -> int:
  255. return self._get_single_token_id("<|nospeech|>")
  256. @property
  257. @lru_cache()
  258. def no_timestamps(self) -> int:
  259. return self._get_single_token_id("<|notimestamps|>")
  260. @property
  261. @lru_cache()
  262. def timestamp_begin(self) -> int:
  263. return self.tokenizer.all_special_ids[-1] + 1
  264. @property
  265. @lru_cache()
  266. def language_token(self) -> int:
  267. """Returns the token id corresponding to the value of the `language` field"""
  268. if self.language is None:
  269. raise ValueError("This tokenizer does not have language token configured")
  270. additional_tokens = dict(
  271. zip(
  272. self.tokenizer.additional_special_tokens,
  273. self.tokenizer.additional_special_tokens_ids,
  274. )
  275. )
  276. candidate = f"<|{self.language}|>"
  277. if candidate in additional_tokens:
  278. return additional_tokens[candidate]
  279. raise KeyError(f"Language {self.language} not found in tokenizer.")
  280. @property
  281. @lru_cache()
  282. def all_language_tokens(self) -> Tuple[int]:
  283. result = []
  284. for token, token_id in zip(
  285. self.tokenizer.additional_special_tokens,
  286. self.tokenizer.additional_special_tokens_ids,
  287. ):
  288. if token.strip("<|>") in LANGUAGES:
  289. result.append(token_id)
  290. return tuple(result)
  291. @property
  292. @lru_cache()
  293. def all_language_codes(self) -> Tuple[str]:
  294. return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
  295. @property
  296. @lru_cache()
  297. def sot_sequence_including_notimestamps(self) -> Tuple[int]:
  298. return tuple(list(self.sot_sequence) + [self.no_timestamps])
  299. @property
  300. @lru_cache()
  301. def non_speech_tokens(self) -> Tuple[int]:
  302. """
  303. Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
  304. annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
  305. - ♪♪♪
  306. - ( SPEAKING FOREIGN LANGUAGE )
  307. - [DAVID] Hey there,
  308. keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
  309. """
  310. symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
  311. symbols += (
  312. "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
  313. )
  314. # symbols that may be a single token or multiple tokens depending on the tokenizer.
  315. # In case they're multiple tokens, suppress the first token, which is safe because:
  316. # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
  317. # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
  318. miscellaneous = set("♩♪♫♬♭♮♯")
  319. assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
  320. # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
  321. result = {
  322. self.tokenizer.encode(" -").input_ids[0],
  323. self.tokenizer.encode(" '").input_ids[0],
  324. }
  325. for symbol in symbols + list(miscellaneous):
  326. for tokens in [
  327. self.tokenizer.encode(symbol).input_ids,
  328. self.tokenizer.encode(" " + symbol).input_ids,
  329. ]:
  330. if len(tokens) == 1 or symbol in miscellaneous:
  331. result.add(tokens[0])
  332. return tuple(sorted(result))
  333. def _get_single_token_id(self, text) -> int:
  334. tokens = self.tokenizer.encode(text).input_ids
  335. assert len(tokens) == 1, f"{text} is not encoded as a single token"
  336. return tokens[0]
  337. @lru_cache(maxsize=None)
  338. def build_tokenizer(resource_path: str, name: str = "gpt2"):
  339. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  340. path = os.path.join(resource_path, "assets", name)
  341. tokenizer = GPTTokenizer.from_pretrained(path)
  342. specials = [
  343. "<|startoftranscript|>",
  344. *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
  345. "<|translate|>",
  346. "<|transcribe|>",
  347. "<|startoflm|>",
  348. "<|startofprev|>",
  349. "<|nospeech|>",
  350. "<|notimestamps|>",
  351. ]
  352. tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
  353. return tokenizer
  354. @lru_cache(maxsize=None)
  355. def get_tokenizer(
  356. multilingual: bool,
  357. resource_path: str,
  358. *,
  359. task: Optional[str] = None, # Literal["transcribe", "translate", None]
  360. language: Optional[str] = None,
  361. ) -> Tokenizer:
  362. if language is not None:
  363. language = language.lower()
  364. if language not in LANGUAGES:
  365. if language in TO_LANGUAGE_CODE:
  366. language = TO_LANGUAGE_CODE[language]
  367. else:
  368. raise ValueError(f"Unsupported language: {language}")
  369. if multilingual:
  370. tokenizer_name = "multilingual"
  371. task = task or "transcribe"
  372. language = language or "en"
  373. else:
  374. tokenizer_name = "gpt2"
  375. task = None
  376. language = None
  377. tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
  378. all_special_ids: List[int] = tokenizer.all_special_ids
  379. sot: int = all_special_ids[1]
  380. translate: int = all_special_ids[-6]
  381. transcribe: int = all_special_ids[-5]
  382. langs = tuple(LANGUAGES.keys())
  383. sot_sequence = [sot]
  384. if language is not None:
  385. sot_sequence.append(sot + 1 + langs.index(language))
  386. if task is not None:
  387. sot_sequence.append(transcribe if task == "transcribe" else translate)
  388. return Tokenizer(
  389. tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
  390. )
  391. class MultiHeadAttention(paddle.nn.Layer):
  392. def __init__(self, n_state: int, n_head: int):
  393. super().__init__()
  394. self.n_head = n_head
  395. self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  396. self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
  397. self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  398. self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  399. def forward(
  400. self,
  401. x: paddle.Tensor,
  402. xa: Optional[paddle.Tensor] = None,
  403. mask: Optional[paddle.Tensor] = None,
  404. kv_cache: Optional[dict] = None,
  405. ):
  406. q = self.query(x)
  407. if kv_cache is None or xa is None or self.key not in kv_cache:
  408. # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
  409. # otherwise, perform key/value projections for self- or cross-attention as usual.
  410. k = self.key(x if xa is None else xa)
  411. v = self.value(x if xa is None else xa)
  412. else:
  413. # for cross-attention, calculate keys and values once and reuse in subsequent calls.
  414. k = kv_cache[self.key]
  415. v = kv_cache[self.value]
  416. wv = self.qkv_attention(q, k, v, mask)
  417. return self.out(wv)
  418. def qkv_attention(
  419. self,
  420. q: paddle.Tensor,
  421. k: paddle.Tensor,
  422. v: paddle.Tensor,
  423. mask: Optional[paddle.Tensor] = None,
  424. ):
  425. n_batch, n_ctx, n_state = q.shape
  426. scale = (n_state // self.n_head) ** -0.25
  427. q = (
  428. paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  429. * scale
  430. )
  431. k = (
  432. paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
  433. * scale
  434. )
  435. v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  436. qk = q @ k
  437. if mask is not None:
  438. qk = qk + mask[:n_ctx, :n_ctx]
  439. w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
  440. return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
  441. class ResidualAttentionBlock(paddle.nn.Layer):
  442. def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
  443. super().__init__()
  444. self.attn = MultiHeadAttention(n_state, n_head)
  445. self.attn_ln = paddle.nn.LayerNorm(n_state)
  446. self.cross_attn = (
  447. MultiHeadAttention(n_state, n_head) if cross_attention else None
  448. )
  449. self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None
  450. n_mlp = n_state * 4
  451. self.mlp = paddle.nn.Sequential(
  452. paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
  453. paddle.nn.GELU(),
  454. paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
  455. )
  456. self.mlp_ln = paddle.nn.LayerNorm(n_state)
  457. def forward(
  458. self,
  459. x: paddle.Tensor,
  460. xa: Optional[paddle.Tensor] = None,
  461. mask: Optional[paddle.Tensor] = None,
  462. kv_cache: Optional[dict] = None,
  463. ):
  464. x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
  465. if self.cross_attn:
  466. x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
  467. x = x + self.mlp(self.mlp_ln(x))
  468. return x
  469. def sinusoids(length, channels, max_timescale=10000):
  470. """Returns sinusoids for positional embedding"""
  471. assert channels % 2 == 0
  472. log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
  473. inv_timescales = paddle.exp(
  474. -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
  475. )
  476. scaled_time = (
  477. paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
  478. * inv_timescales[np.newaxis, :]
  479. )
  480. return paddle.to_tensor(
  481. paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
  482. )
  483. class AudioEncoder(paddle.nn.Layer):
  484. def __init__(
  485. self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  486. ):
  487. super().__init__()
  488. self.conv1 = paddle.nn.Conv1D(
  489. n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
  490. )
  491. self.conv2 = paddle.nn.Conv1D(
  492. n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
  493. )
  494. self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
  495. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  496. [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
  497. )
  498. self.ln_post = paddle.nn.LayerNorm(n_state)
  499. def forward(self, x: paddle.Tensor):
  500. """
  501. x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
  502. the mel spectrogram of the audio
  503. """
  504. x = paddle.nn.functional.gelu(self.conv1(x))
  505. x = paddle.nn.functional.gelu(self.conv2(x))
  506. x = paddle.transpose(x, (0, 2, 1))
  507. assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
  508. x = x + self.positional_embedding
  509. for block in self.blocks:
  510. x = block(x)
  511. x = self.ln_post(x)
  512. return x
  513. class TextDecoder(paddle.nn.Layer):
  514. def __init__(
  515. self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  516. ):
  517. super().__init__()
  518. self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
  519. self.positional_embedding = paddle.create_parameter(
  520. shape=[n_ctx, n_state], dtype="float32"
  521. )
  522. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  523. [
  524. ResidualAttentionBlock(n_state, n_head, cross_attention=True)
  525. for _ in range(n_layer)
  526. ]
  527. )
  528. self.ln = paddle.nn.LayerNorm(n_state)
  529. mask = paddle.full(shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32")
  530. mask = paddle.triu(mask, diagonal=1)
  531. self.register_buffer("mask", mask, persistable=False)
  532. def forward(
  533. self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
  534. ):
  535. """
  536. x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
  537. the text tokens
  538. xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
  539. the encoded audio features to be attended on
  540. """
  541. offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
  542. x = (
  543. self.token_embedding(x)
  544. + self.positional_embedding[offset : offset + x.shape[-1]]
  545. )
  546. x = x.to(xa.dtype)
  547. for block in self.blocks:
  548. x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
  549. x = self.ln(x)
  550. logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
  551. return logits
  552. @dataclass(frozen=True)
  553. class DecodingOptions:
  554. task: str = (
  555. "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
  556. )
  557. language: Optional[str] = (
  558. None # language that the audio is in; uses detected language if None
  559. )
  560. # sampling-related options
  561. temperature: float = 0.0
  562. sample_len: Optional[int] = None # maximum number of tokens to sample
  563. best_of: Optional[int] = (
  564. None # number of independent samples to collect, when t > 0
  565. )
  566. beam_size: Optional[int] = None # number of beams in beam search, when t == 0
  567. patience: Optional[float] = (
  568. None # patience in beam search (https://arxiv.org/abs/2204.05424)
  569. )
  570. # options for ranking generations (either beams or best-of-N samples)
  571. length_penalty: Optional[float] = (
  572. None # "alpha" in Google NMT, None defaults to length norm
  573. )
  574. # prompt, prefix, and token suppression
  575. prompt: Optional[Union[str, List[int]]] = (
  576. None # text or tokens for the previous context
  577. )
  578. prefix: Optional[Union[str, List[int]]] = (
  579. None # text or tokens to prefix the current context
  580. )
  581. suppress_blank: bool = True # this will suppress blank outputs
  582. # list of tokens ids (or comma-separated token ids) to suppress
  583. # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
  584. suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
  585. # timestamp sampling options
  586. without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
  587. max_initial_timestamp: Optional[float] = (
  588. 1.0 # the initial timestamp cannot be later than this
  589. )
  590. # implementation details
  591. fp16: bool = False # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record describing the outcome of decoding one audio segment.

    ``audio_features`` and ``language`` are required; every other field has a
    default, and the numeric statistics default to NaN. The code that fills
    this record is not in the visible part of the file, so the per-field notes
    are inferred from the field names -- TODO confirm against the decoder.
    """

    audio_features: paddle.Tensor  # presumably the encoder output for the segment
    language: str  # presumably the language code used for decoding
    language_probs: Optional[Dict[str, float]] = None  # presumably per-language probabilities
    tokens: List[int] = field(default_factory=list)  # fresh empty list per instance by default
    text: str = ""  # presumably the decoded text
    avg_logprob: float = np.nan  # presumably the mean token log-probability
    no_speech_prob: float = np.nan  # presumably the probability of the no-speech token
    temperature: float = np.nan  # presumably the sampling temperature that was used
    compression_ratio: float = np.nan  # presumably compression_ratio() of the decoded text
class Inference:
    """Interface for the decoder forward pass used while decoding.

    ``logits`` and ``rearrange_kv_cache`` raise ``NotImplementedError`` here and
    must be overridden; ``cleanup_caching`` does nothing unless overridden.
    """

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
  614. class WhisperInference(Inference):
  615. def __init__(self, model: "Whisper", initial_token_length: int):
  616. self.model: "Whisper" = model
  617. self.initial_token_length = initial_token_length
  618. self.kv_cache = {}
  619. self.hooks = []
  620. def logits(
  621. self, tokens: paddle.Tensor, audio_features: paddle.Tensor
  622. ) -> paddle.Tensor:
  623. if not self.kv_cache:
  624. self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
  625. if tokens.shape[-1] > self.initial_token_length:
  626. # only need to use the last token except in the first forward pass
  627. tokens = tokens[:, -1:]
  628. return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  629. def cleanup_caching(self):
  630. for hook in self.hooks:
  631. hook.remove()
  632. self.kv_cache = {}
  633. self.hooks = []
  634. def rearrange_kv_cache(self, source_indices):
  635. for module, tensor in self.kv_cache.items():
  636. # update the key/value cache to contain the selected sequences
  637. self.kv_cache[module] = tensor[source_indices].detach()
  638. @paddle.no_grad()
  639. def detect_language(
  640. model: "Whisper",
  641. mel: paddle.Tensor,
  642. resource_path: str,
  643. tokenizer: Tokenizer = None,
  644. ) -> Tuple[paddle.Tensor, List[dict]]:
  645. """
  646. Detect the spoken language in the audio, and return them as list of strings, along with the ids
  647. of the most probable language tokens and the probability distribution over all language tokens.
  648. This is performed outside the main decode loop in order to not interfere with kv-caching.
  649. Returns
  650. -------
  651. language_tokens : Tensor, shape = (batch_size,)
  652. ids of the most probable language tokens, which appears after the startoftranscript token.
  653. language_probs : List[Dict[str, float]], length = batch_size
  654. list of dictionaries containing the probability distribution over all languages.
  655. """
  656. if tokenizer is None:
  657. tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
  658. if (
  659. tokenizer.language is None
  660. or tokenizer.language_token not in tokenizer.sot_sequence
  661. ):
  662. raise ValueError(
  663. "This model doesn't have language tokens so it can't perform lang id"
  664. )
  665. single = mel.ndim == 2
  666. if single:
  667. mel = mel.unsqueeze(0)
  668. # skip encoder forward pass if already-encoded audio features were given
  669. if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
  670. mel = model.encoder(mel)
  671. # forward pass using a single token, startoftranscript
  672. batch_size = mel.shape[0]
  673. x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
  674. logits = model.logits(x, mel)[:, 0]
  675. # collect detected languages; suppress all non-language tokens
  676. mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
  677. mask[list(tokenizer.all_language_tokens)] = False
  678. logits[:, mask] = -np.inf
  679. language_tokens = paddle.argmax(logits, axis=-1)
  680. language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
  681. language_probs = [
  682. {
  683. c: language_token_probs[i, j].tolist()
  684. for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
  685. }
  686. for i in range(batch_size)
  687. ]
  688. if single:
  689. language_tokens = language_tokens[0]
  690. language_probs = language_probs[0]
  691. return language_tokens, language_probs
@function_requires_deps("tqdm")
def transcribe(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
    """
    Transcribe mel audio features using Whisper

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    mel: paddle.Tensor
        The audio feature; frames are read along the last axis (`mel.shape[-1]`)
        and windows are taken with `mel[:, seek:]`

    resource_path: str
        Passed through to `get_tokenizer`, `model.detect_language` and `model.decode`

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances; the keys
        "language", "task" and "initial_prompt" are also read directly here

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = np.float32  # paddle only support float32
    if dtype == np.float32:
        decode_options["fp16"] = False

    # a missing language, None, or the literal string "None" all trigger auto-selection
    if (
        decode_options.get("language") == "None"
        or decode_options.get("language", None) is None
    ):
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            # pick the language code with the highest probability
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        resource_path=resource_path,
        language=language,
        task=task,
    )

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        # Try each temperature in turn and stop at the first result that passes
        # both quality thresholds; the last attempt is returned if none passes.
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    # "initial_prompt" is removed from decode_options so it is not forwarded to DecodingOptions
    initial_prompt = decode_options.pop("initial_prompt", None)
    if initial_prompt and initial_prompt != "None":
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    else:
        initial_prompt = []

    def add_segment(
        *,
        start: float,
        end: float,
        text_tokens: paddle.Tensor,
        result: DecodingResult,
    ):
        # only tokens below `tokenizer.eot` are decoded into text
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot]
        )
        if len(text.strip()) == 0:  # skip empty text output
            return

        # `seek` is read from the enclosing scope at call time
        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            # tokens accumulated since the last prompt reset are fed back as the prompt
            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[
                        -1
                    ]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin)
            )
            # positions where a timestamp token is immediately followed by another one
            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            if (
                len(consecutive) > 0
            ):  # if the output contains two consecutive timestamp tokens
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
                last_slice = 0
                for current_slice in consecutive:
                    # each slice starts and ends with a timestamp token
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset
                        + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                # advance only up to the last closed timestamp
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    # the initial prompt tokens are excluded from the returned text
    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
        segments=all_segments,
        language=language,
    )
  930. class SequenceRanker:
  931. def rank(
  932. self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
  933. ) -> List[int]:
  934. """
  935. Given a list of groups of samples and their cumulative log probabilities,
  936. return the indices of the samples in each group to select as the final result
  937. """
  938. raise NotImplementedError
  939. class MaximumLikelihoodRanker(SequenceRanker):
  940. """
  941. Select the sample with the highest log probabilities, penalized using either
  942. a simple length normalization or Google NMT paper's length penalty
  943. """
  944. def __init__(self, length_penalty: Optional[float]):
  945. self.length_penalty = length_penalty
  946. def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
  947. def scores(logprobs, lengths):
  948. result = []
  949. for logprob, length in zip(logprobs, lengths):
  950. if self.length_penalty is None or self.length_penalty == "None":
  951. penalty = length
  952. else:
  953. # from the Google NMT paper
  954. penalty = ((5 + length) / 6) ** self.length_penalty
  955. result.append(logprob / penalty)
  956. return result
  957. # get the sequence with the highest score
  958. lengths = [[len(t) for t in s] for s in tokens]
  959. return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
  960. class TokenDecoder:
  961. def reset(self):
  962. """Initialize any stateful variables for decoding a new sequence"""
  963. def update(
  964. self,
  965. tokens: paddle.Tensor,
  966. logits: paddle.Tensor,
  967. sum_logprobs: paddle.Tensor,
  968. ) -> Tuple[paddle.Tensor, bool]:
  969. """Specify how to select the next token, based on the current trace and logits
  970. Parameters
  971. ----------
  972. tokens : Tensor, shape = (n_batch, current_sequence_length)
  973. all tokens in the context so far, including the prefix and sot_sequence tokens
  974. logits : Tensor, shape = (n_batch, vocab_size)
  975. per-token logits of the probability distribution at the current step
  976. sum_logprobs : Tensor, shape = (n_batch)
  977. cumulative log probabilities for each sequence
  978. Returns
  979. -------
  980. tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
  981. the tokens, appended with the selected next token
  982. completed : bool
  983. True if all sequences has reached the end of text
  984. """
  985. raise NotImplementedError
  986. def finalize(
  987. self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
  988. ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
  989. """Finalize search and return the final candidate sequences
  990. Parameters
  991. ----------
  992. tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
  993. all tokens in the context so far, including the prefix and sot_sequence
  994. sum_logprobs : Tensor, shape = (batch_size, beam_size)
  995. cumulative log probabilities for each sequence
  996. Returns
  997. -------
  998. tokens : Sequence[Sequence[Tensor]], length = batch_size
  999. sequence of Tensors containing candidate token sequences, for each audio input
  1000. sum_logprobs : List[List[float]], length = batch_size
  1001. sequence of cumulative log probabilities corresponding to the above
  1002. """
  1003. raise NotImplementedError
  1004. class GreedyDecoder(TokenDecoder):
  1005. def __init__(self, temperature: float, eot: int):
  1006. self.temperature = temperature
  1007. self.eot = eot
  1008. def update(
  1009. self,
  1010. tokens: paddle.Tensor,
  1011. logits: paddle.Tensor,
  1012. sum_logprobs: paddle.Tensor,
  1013. ) -> Tuple[paddle.Tensor, bool]:
  1014. temperature = self.temperature
  1015. if temperature == 0:
  1016. next_tokens = paddle.argmax(logits, axis=-1)
  1017. else:
  1018. next_tokens = paddle.distribution.Categorical(
  1019. logits=logits / temperature
  1020. ).sample([1])
  1021. next_tokens = paddle.reshape(
  1022. next_tokens,
  1023. [
  1024. next_tokens.shape[0] * next_tokens.shape[1],
  1025. ],
  1026. )
  1027. logprobs = paddle.nn.functional.log_softmax(
  1028. logits, axis=-1, dtype=paddle.float32
  1029. )
  1030. current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
  1031. sum_logprobs += current_logprobs * paddle.to_tensor(
  1032. (tokens[:, -1] != self.eot), dtype=paddle.float32
  1033. )
  1034. next_tokens[tokens[:, -1] == self.eot] = self.eot
  1035. tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
  1036. completed = paddle.all((tokens[:, -1] == self.eot))
  1037. return tokens, completed
  1038. def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1039. # make sure each sequence has at least one EOT token at the end
  1040. tokens = paddle.nn.functional.pad(
  1041. tokens, (0, 1), value=self.eot, data_format="NCL"
  1042. )
  1043. return tokens, sum_logprobs.tolist()
  1044. class BeamSearchDecoder(TokenDecoder):
  1045. def __init__(
  1046. self,
  1047. beam_size: int,
  1048. eot: int,
  1049. inference: Inference,
  1050. patience: Optional[float] = None,
  1051. ):
  1052. self.beam_size = beam_size
  1053. self.eot = eot
  1054. self.inference = inference
  1055. self.patience = patience or 1.0
  1056. if patience is None or patience == "None":
  1057. self.patience = 1.0
  1058. else:
  1059. self.patience = patience
  1060. self.max_candidates: int = round(beam_size * self.patience)
  1061. self.finished_sequences = None
  1062. assert (
  1063. self.max_candidates > 0
  1064. ), f"Invalid beam size ({beam_size}) or patience ({patience})"
  1065. def reset(self):
  1066. self.finished_sequences = None
  1067. def update(
  1068. self,
  1069. tokens: paddle.Tensor,
  1070. logits: paddle.Tensor,
  1071. sum_logprobs: paddle.Tensor,
  1072. ) -> Tuple[paddle.Tensor, bool]:
  1073. if tokens.shape[0] % self.beam_size != 0:
  1074. raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
  1075. batch_size = tokens.shape[0] // self.beam_size
  1076. if self.finished_sequences is None: # for the first update
  1077. self.finished_sequences = [{} for _ in range(batch_size)]
  1078. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1079. next_tokens, source_indices, finished_sequences = [], [], []
  1080. for i in range(batch_size):
  1081. scores, sources, finished = {}, {}, {}
  1082. # STEP 1: calculate the cumulative log probabilities for possible candidates
  1083. for j in range(self.beam_size):
  1084. idx = i * self.beam_size + j
  1085. prefix = tokens[idx].tolist()
  1086. logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
  1087. for logprob, token in zip(logprob, token):
  1088. new_logprob = (sum_logprobs[idx] + logprob).item()
  1089. sequence = tuple(prefix + [token.item()])
  1090. scores[sequence] = new_logprob
  1091. sources[sequence] = idx
  1092. # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
  1093. saved = 0
  1094. for sequence in sorted(scores, key=scores.get, reverse=True):
  1095. if sequence[-1] == self.eot:
  1096. finished[sequence] = scores[sequence]
  1097. else:
  1098. sum_logprobs[len(next_tokens)] = scores[sequence]
  1099. next_tokens.append(sequence)
  1100. source_indices.append(sources[sequence])
  1101. saved += 1
  1102. if saved == self.beam_size:
  1103. break
  1104. finished_sequences.append(finished)
  1105. tokens = paddle.to_tensor(next_tokens)
  1106. self.inference.rearrange_kv_cache(source_indices)
  1107. # add newly finished sequences to self.finished_sequences
  1108. assert len(self.finished_sequences) == len(finished_sequences)
  1109. for previously_finished, newly_finished in zip(
  1110. self.finished_sequences, finished_sequences
  1111. ):
  1112. for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
  1113. if len(previously_finished) >= self.max_candidates:
  1114. break # the candidate list is full
  1115. previously_finished[seq] = newly_finished[seq]
  1116. # mark as completed if all audio has enough number of samples
  1117. completed = all(
  1118. len(sequences) >= self.max_candidates
  1119. for sequences in self.finished_sequences
  1120. )
  1121. return tokens, completed
  1122. def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1123. # collect all finished sequences, including patience, and add unfinished ones if not enough
  1124. sum_logprobs = sum_logprobs.cpu()
  1125. for i, sequences in enumerate(self.finished_sequences):
  1126. if (
  1127. len(sequences) < self.beam_size
  1128. ): # when not enough sequences are finished
  1129. for j in list(np.argsort(sum_logprobs[i]))[::-1]:
  1130. sequence = preceding_tokens[i, j].tolist() + [self.eot]
  1131. sequences[tuple(sequence)] = sum_logprobs[i][j].item()
  1132. if len(sequences) >= self.beam_size:
  1133. break
  1134. tokens: List[List[paddle.Tensor]] = [
  1135. [paddle.to_tensor(seq) for seq in sequences.keys()]
  1136. for sequences in self.finished_sequences
  1137. ]
  1138. sum_logprobs: List[List[float]] = [
  1139. list(sequences.values()) for sequences in self.finished_sequences
  1140. ]
  1141. return tokens, sum_logprobs
  1142. class LogitFilter:
  1143. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
  1144. """Apply any filtering or masking to logits in-place
  1145. Parameters
  1146. ----------
  1147. logits : Tensor, shape = (n_batch, vocab_size)
  1148. per-token logits of the probability distribution at the current step
  1149. tokens : Tensor, shape = (n_batch, current_sequence_length)
  1150. all tokens in the context so far, including the prefix and sot_sequence tokens
  1151. """
  1152. raise NotImplementedError
  1153. class SuppressBlank(LogitFilter):
  1154. def __init__(self, tokenizer: Tokenizer, sample_begin: int):
  1155. self.tokenizer = tokenizer
  1156. self.sample_begin = sample_begin
  1157. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1158. if tokens.shape[1] == self.sample_begin:
  1159. logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
  1160. -np.inf
  1161. )
  1162. class SuppressTokens(LogitFilter):
  1163. def __init__(self, suppress_tokens: Sequence[int]):
  1164. self.suppress_tokens = list(suppress_tokens)
  1165. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1166. logits[:, self.suppress_tokens] = -np.inf
  1167. class ApplyTimestampRules(LogitFilter):
  1168. def __init__(
  1169. self,
  1170. tokenizer: Tokenizer,
  1171. sample_begin: int,
  1172. max_initial_timestamp_index: Optional[int],
  1173. ):
  1174. self.tokenizer = tokenizer
  1175. self.sample_begin = sample_begin
  1176. self.max_initial_timestamp_index = max_initial_timestamp_index
  1177. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1178. # suppress <|notimestamps|> which is handled by without_timestamps
  1179. if self.tokenizer.no_timestamps is not None:
  1180. logits[:, self.tokenizer.no_timestamps] = -np.inf
  1181. # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
  1182. for k in range(tokens.shape[0]):
  1183. seq = [t for t in tokens[k, self.sample_begin :].tolist()]
  1184. last_was_timestamp = (
  1185. len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
  1186. )
  1187. penultimate_was_timestamp = (
  1188. len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
  1189. )
  1190. if last_was_timestamp:
  1191. if penultimate_was_timestamp: # has to be non-timestamp
  1192. logits[k, self.tokenizer.timestamp_begin :] = -np.inf
  1193. else: # cannot be normal text tokens
  1194. logits[k, : self.tokenizer.eot] = -np.inf
  1195. # apply the `max_initial_timestamp` option
  1196. if (
  1197. tokens.shape[1] == self.sample_begin
  1198. and self.max_initial_timestamp_index is not None
  1199. ):
  1200. last_allowed = (
  1201. self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
  1202. )
  1203. logits[:, last_allowed + 1 :] = -np.inf
  1204. # if sum of probability over timestamps is above any other token, sample timestamp
  1205. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1206. for k in range(tokens.shape[0]):
  1207. # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
  1208. # To bypass this issue in CI, we have decomposed the operation into separate steps.
  1209. # It will raise 2e-6 difference in precision.
  1210. # TODO: revert this after logsumexp been fixed.
  1211. timestamp_logprob = paddle.exp(
  1212. logprobs[k, self.tokenizer.timestamp_begin :]
  1213. )
  1214. timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
  1215. timestamp_logprob = paddle.log(timestamp_logprob)
  1216. max_text_token_logprob = paddle.max(
  1217. logprobs[k, : self.tokenizer.timestamp_begin]
  1218. )
  1219. if timestamp_logprob > max_text_token_logprob:
  1220. logits[k, : self.tokenizer.timestamp_begin] = -np.inf
  1221. class DecodingTask:
  1222. inference: Inference
  1223. sequence_ranker: SequenceRanker
  1224. decoder: TokenDecoder
  1225. logit_filters: List[LogitFilter]
  1226. def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
  1227. self.model = model
  1228. language = options.language or "en"
  1229. tokenizer = get_tokenizer(
  1230. model.is_multilingual,
  1231. resource_path=resource_path,
  1232. language=language,
  1233. task=options.task,
  1234. )
  1235. self.tokenizer: Tokenizer = tokenizer
  1236. self.options: DecodingOptions = self._verify_options(options)
  1237. self.resource_path: str = resource_path
  1238. self.beam_size: int = options.beam_size or options.best_of or 1
  1239. self.n_ctx: int = model.dims.n_text_ctx
  1240. self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
  1241. self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
  1242. if self.options.without_timestamps:
  1243. self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
  1244. self.initial_tokens: Tuple[int] = self._get_initial_tokens()
  1245. self.sample_begin: int = len(self.initial_tokens)
  1246. self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
  1247. # inference: implements the forward pass through the decoder, including kv caching
  1248. self.inference = WhisperInference(model, len(self.initial_tokens))
  1249. # sequence ranker: implements how to rank a group of sampled sequences
  1250. self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
  1251. # decoder: implements how to select the next tokens, given the autoregressive distribution
  1252. if options.beam_size is not None:
  1253. self.decoder = BeamSearchDecoder(
  1254. options.beam_size, tokenizer.eot, self.inference, options.patience
  1255. )
  1256. else:
  1257. self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
  1258. # logit filters: applies various rules to suppress or penalize certain tokens
  1259. self.logit_filters = []
  1260. if self.options.suppress_blank:
  1261. self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
  1262. if self.options.suppress_tokens:
  1263. self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
  1264. if not options.without_timestamps:
  1265. precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
  1266. max_initial_timestamp_index = None
  1267. if options.max_initial_timestamp:
  1268. max_initial_timestamp_index = round(
  1269. self.options.max_initial_timestamp / precision
  1270. )
  1271. self.logit_filters.append(
  1272. ApplyTimestampRules(
  1273. tokenizer, self.sample_begin, max_initial_timestamp_index
  1274. )
  1275. )
  1276. def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
  1277. if options.beam_size is not None and options.best_of is not None:
  1278. raise ValueError("beam_size and best_of can't be given together")
  1279. if options.temperature == 0:
  1280. if options.best_of is not None:
  1281. raise ValueError("best_of with greedy sampling (T=0) is not compatible")
  1282. if options.patience is not None and options.beam_size is None:
  1283. raise ValueError("patience requires beam_size to be given")
  1284. if options.length_penalty is not None and options.length_penalty != "None":
  1285. if not (0 <= options.length_penalty <= 1):
  1286. raise ValueError(
  1287. "length_penalty (alpha) should be a value between 0 and 1"
  1288. )
  1289. return options
  1290. def _get_initial_tokens(self) -> Tuple[int]:
  1291. tokens = list(self.sot_sequence)
  1292. prefix = self.options.prefix
  1293. prompt = self.options.prompt
  1294. if prefix:
  1295. prefix_tokens = (
  1296. self.tokenizer.encode(" " + prefix.strip().input_ids)
  1297. if isinstance(prefix, str)
  1298. else prefix
  1299. )
  1300. if self.sample_len is not None:
  1301. max_prefix_len = self.n_ctx // 2 - self.sample_len
  1302. prefix_tokens = prefix_tokens[-max_prefix_len:]
  1303. tokens = tokens + prefix_tokens
  1304. if prompt:
  1305. prompt_tokens = (
  1306. self.tokenizer.encode(" " + prompt.strip().input_ids)
  1307. if isinstance(prompt, str)
  1308. else prompt
  1309. )
  1310. tokens = (
  1311. [self.tokenizer.sot_prev]
  1312. + prompt_tokens[-(self.n_ctx // 2 - 1) :]
  1313. + tokens
  1314. )
  1315. return tuple(tokens)
  1316. def _get_suppress_tokens(self) -> Tuple[int]:
  1317. suppress_tokens = self.options.suppress_tokens
  1318. if isinstance(suppress_tokens, str):
  1319. suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
  1320. if -1 in suppress_tokens:
  1321. suppress_tokens = [t for t in suppress_tokens if t >= 0]
  1322. suppress_tokens.extend(self.tokenizer.non_speech_tokens)
  1323. elif suppress_tokens is None or len(suppress_tokens) == 0:
  1324. suppress_tokens = [] # interpret empty string as an empty list
  1325. else:
  1326. assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
  1327. suppress_tokens.extend(
  1328. [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
  1329. )
  1330. if self.tokenizer.no_speech is not None:
  1331. # no-speech probability is collected separately
  1332. suppress_tokens.append(self.tokenizer.no_speech)
  1333. return tuple(sorted(set(suppress_tokens)))
  1334. def _get_audio_features(self, mel: paddle.Tensor):
  1335. if mel.shape[-2:] == (
  1336. self.model.dims.n_audio_ctx,
  1337. self.model.dims.n_audio_state,
  1338. ):
  1339. # encoded audio features are given; skip audio encoding
  1340. audio_features = mel
  1341. else:
  1342. audio_features = self.model.encoder(mel)
  1343. return audio_features
  1344. def _detect_language(
  1345. self,
  1346. audio_features: paddle.Tensor,
  1347. tokens: paddle.Tensor,
  1348. resource_path: str,
  1349. ):
  1350. languages = [self.options.language] * audio_features.shape[0]
  1351. lang_probs = None
  1352. if self.options.language is None or self.options.task == "lang_id":
  1353. lang_tokens, lang_probs = self.model.detect_language(
  1354. audio_features, self.tokenizer, self.resource_path
  1355. )
  1356. languages = [max(probs, key=probs.get) for probs in lang_probs]
  1357. if self.options.language is None:
  1358. tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
  1359. return languages, lang_probs
  1360. def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
  1361. assert audio_features.shape[0] == tokens.shape[0]
  1362. n_batch = tokens.shape[0]
  1363. sum_logprobs: paddle.Tensor = paddle.zeros(
  1364. paddle.to_tensor(n_batch), dtype=paddle.float32
  1365. )
  1366. no_speech_probs = [np.nan] * n_batch
  1367. try:
  1368. for i in range(self.sample_len):
  1369. logits = self.inference.logits(tokens, audio_features)
  1370. if (
  1371. i == 0 and self.tokenizer.no_speech is not None
  1372. ): # save no_speech_probs
  1373. probs_at_sot = paddle.nn.functional.softmax(
  1374. logits[:, self.sot_index], axis=-1, dtype=paddle.float32
  1375. )
  1376. no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
  1377. # now we need to consider the logits at the last token only
  1378. logits = logits[:, -1]
  1379. # apply the logit filters, e.g. for suppressing or applying penalty to
  1380. for logit_filter in self.logit_filters:
  1381. logit_filter.apply(logits, tokens)
  1382. # expand the tokens tensor with the selected next tokens
  1383. tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
  1384. if completed or tokens.shape[-1] > self.n_ctx:
  1385. break
  1386. finally:
  1387. self.inference.cleanup_caching()
  1388. return tokens, sum_logprobs, no_speech_probs
  1389. @paddle.no_grad()
  1390. def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
  1391. self.decoder.reset()
  1392. tokenizer: Tokenizer = self.tokenizer
  1393. batch_size: int = mel.shape[0]
  1394. audio_features: paddle.Tensor = self._get_audio_features(
  1395. mel
  1396. ) # encoder forward pass
  1397. tokens: paddle.Tensor
  1398. if batch_size > 1:
  1399. for i in range(batch_size):
  1400. tokens = paddle.concat(
  1401. x=[
  1402. paddle.to_tensor([self.initial_tokens]),
  1403. paddle.to_tensor([self.initial_tokens]),
  1404. ],
  1405. axis=0,
  1406. )
  1407. elif batch_size == 1:
  1408. tokens = paddle.to_tensor([self.initial_tokens])
  1409. # detect language if requested, overwriting the language token
  1410. languages, language_probs = self._detect_language(
  1411. paddle.to_tensor(audio_features),
  1412. paddle.to_tensor(tokens),
  1413. self.resource_path,
  1414. )
  1415. if self.options.task == "lang_id":
  1416. return [
  1417. DecodingResult(
  1418. audio_features=features, language=language, language_probs=probs
  1419. )
  1420. for features, language, probs in zip(
  1421. audio_features, languages, language_probs
  1422. )
  1423. ]
  1424. # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
  1425. audio_features = paddle.repeat_interleave(
  1426. audio_features, self.beam_size, axis=0
  1427. )
  1428. tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
  1429. # call the main sampling loop
  1430. tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  1431. # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
  1432. audio_features = audio_features[:: self.beam_size]
  1433. no_speech_probs = no_speech_probs[:: self.beam_size]
  1434. assert audio_features.shape[0] == len(no_speech_probs) == batch_size
  1435. tokens = tokens.reshape([batch_size, self.beam_size, -1])
  1436. sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
  1437. # get the final candidates for each group, and slice between the first sampled token and EOT
  1438. tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
  1439. tokens: List[List[paddle.Tensor]] = [
  1440. [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
  1441. for s in tokens
  1442. ]
  1443. # select the top-ranked sample in each group
  1444. selected = self.sequence_ranker.rank(tokens, sum_logprobs)
  1445. tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
  1446. texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
  1447. sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
  1448. avg_logprobs: List[float] = [
  1449. lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
  1450. ]
  1451. fields = (
  1452. texts,
  1453. languages,
  1454. tokens,
  1455. audio_features,
  1456. avg_logprobs,
  1457. no_speech_probs,
  1458. )
  1459. if len(set(map(len, fields))) != 1:
  1460. raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
  1461. return [
  1462. DecodingResult(
  1463. audio_features=features,
  1464. language=language,
  1465. tokens=tokens,
  1466. text=text,
  1467. avg_logprob=avg_logprob,
  1468. no_speech_prob=no_speech_prob,
  1469. temperature=self.options.temperature,
  1470. compression_ratio=compression_ratio(text),
  1471. )
  1472. for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
  1473. *fields
  1474. )
  1475. ]
  1476. @paddle.no_grad()
  1477. def decode(
  1478. model: "Whisper",
  1479. mel: paddle.Tensor,
  1480. options: DecodingOptions = DecodingOptions(),
  1481. resource_path=str,
  1482. ) -> Union[DecodingResult, List[DecodingResult]]:
  1483. """
  1484. Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
  1485. Parameters
  1486. ----------
  1487. model: Whisper
  1488. the Whisper model instance
  1489. mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
  1490. A tensor containing the Mel spectrogram(s)
  1491. options: DecodingOptions
  1492. A dataclass that contains all necessary options for decoding 30-second segments
  1493. Returns
  1494. -------
  1495. result: Union[DecodingResult, List[DecodingResult]]
  1496. The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
  1497. """
  1498. single = mel.ndim == 2
  1499. if single:
  1500. mel = mel.unsqueeze(0)
  1501. result = DecodingTask(model, options, resource_path).run(mel)
  1502. if single:
  1503. result = result[0]
  1504. return result
  1505. class Whisper(paddle.nn.Layer):
  1506. """
  1507. The `Whisper` module use AudioEncoder and TextDecoder, and return detect_language, transcribe, decode.
  1508. """
  1509. def __init__(self, dims: ModelDimensions):
  1510. super().__init__()
  1511. self.dims = dims
  1512. self.encoder = AudioEncoder(
  1513. self.dims.n_mels,
  1514. self.dims.n_audio_ctx,
  1515. self.dims.n_audio_state,
  1516. self.dims.n_audio_head,
  1517. self.dims.n_audio_layer,
  1518. )
  1519. self.decoder = TextDecoder(
  1520. self.dims.n_vocab,
  1521. self.dims.n_text_ctx,
  1522. self.dims.n_text_state,
  1523. self.dims.n_text_head,
  1524. self.dims.n_text_layer,
  1525. )
  1526. def embed_audio(self, mel: paddle.Tensor):
  1527. return self.encoder.forward(mel)
  1528. def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
  1529. return self.decoder.forward(tokens, audio_features)
  1530. def forward(
  1531. self, mel: paddle.Tensor, tokens: paddle.Tensor
  1532. ) -> Dict[str, paddle.Tensor]:
  1533. return self.decoder(tokens, self.encoder(mel))
  1534. @property
  1535. def device(self):
  1536. return paddle.device.get_device()
  1537. @property
  1538. def is_multilingual(self):
  1539. return self.dims.n_vocab == 51865
  1540. def install_kv_cache_hooks(self, cache: Optional[dict] = None):
  1541. """
  1542. The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
  1543. tensors calculated for the previous positions. This method returns a dictionary that stores
  1544. all caches, and the necessary hooks for the key and value projection modules that save the
  1545. intermediate tensors to be reused during later calculations.
  1546. Returns
  1547. -------
  1548. cache : Dict[nn.Layer, paddle.Tensor]
  1549. A dictionary object mapping the key/value projection modules to its cache
  1550. hooks : List[RemovableHandle]
  1551. List of PyTorch RemovableHandle objects to stop the hooks to be called
  1552. """
  1553. cache = {**cache} if cache is not None else {}
  1554. hooks = []
  1555. def save_to_cache(module, _, output):
  1556. if (
  1557. module not in cache
  1558. or output.shape[1] > self.decoder.positional_embedding.shape[0]
  1559. ):
  1560. cache[module] = (
  1561. output # save as-is, for the first token or cross attention
  1562. )
  1563. else:
  1564. cache[module] = paddle.concat([cache[module], output], axis=1).detach()
  1565. return cache[module]
  1566. def install_hooks(layer: paddle.nn.Layer):
  1567. if isinstance(layer, MultiHeadAttention):
  1568. hooks.append(layer.key.register_forward_post_hook(save_to_cache))
  1569. hooks.append(layer.value.register_forward_post_hook(save_to_cache))
  1570. self.decoder.apply(install_hooks)
  1571. return cache, hooks
  1572. detect_language = detect_language
  1573. set_inference_operations(get_inference_operations() + ["speech_transcribe"])
  1574. transcribe = benchmark.timeit_with_options(name="speech_transcribe")(transcribe)
  1575. decode = decode
  1576. def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
  1577. """
  1578. Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
  1579. """
  1580. if paddle.is_tensor(array):
  1581. if array.shape[axis] > length:
  1582. array = array.index_select(axis=axis, index=paddle.arange(length))
  1583. if array.shape[axis] < length:
  1584. pad_widths = [(0, 0)] * array.ndim
  1585. pad_widths[axis] = (0, length - array.shape[axis])
  1586. array = paddle.transpose(array, (1, 0))
  1587. array = paddle.nn.functional.pad(
  1588. array,
  1589. [pad for sizes in pad_widths[::-1] for pad in sizes],
  1590. data_format="NLC",
  1591. )
  1592. array = paddle.transpose(array, (1, 0))
  1593. else:
  1594. if array.shape[axis] > length:
  1595. array = array.take(indices=range(length), axis=axis)
  1596. if array.shape[axis] < length:
  1597. pad_widths = [(0, 0)] * array.ndim
  1598. pad_widths[axis] = (0, length - array.shape[axis])
  1599. array = paddle.transpose(array, (1, 0))
  1600. array = np.pad(array, pad_widths)
  1601. array = paddle.transpose(array, (1, 0))
  1602. return array
  1603. def hann_window(n_fft: int = N_FFT):
  1604. """
  1605. hanning window
  1606. n_fft: The number of frequency components of the discrete Fourier transform.
  1607. """
  1608. return paddle.to_tensor(
  1609. [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
  1610. dtype=paddle.float32,
  1611. )
  1612. @lru_cache(maxsize=None)
  1613. def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
  1614. """
  1615. load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
  1616. Allows decoupling librosa dependency; saved using:
  1617. np.savez_compressed(
  1618. "mel_filters.npz",
  1619. mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
  1620. )
  1621. """
  1622. assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
  1623. with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
  1624. return paddle.to_tensor(f[f"mel_{n_mels}"])
  1625. @function_requires_deps("soundfile")
  1626. def log_mel_spectrogram(
  1627. audio: Union[str, np.ndarray, paddle.Tensor],
  1628. n_mels: int = N_MELS,
  1629. resource_path: str = None,
  1630. ):
  1631. """
  1632. Compute the log-Mel spectrogram of
  1633. Parameters
  1634. ----------
  1635. audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
  1636. The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
  1637. n_mels: int
  1638. The number of Mel-frequency filters, only 80 is supported
  1639. Returns
  1640. -------
  1641. paddle.Tensor, shape = (80, n_frames)
  1642. A Tensor that contains the Mel spectrogram
  1643. """
  1644. if not paddle.is_tensor(audio):
  1645. if isinstance(audio, str):
  1646. audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
  1647. audio = audio[:, 0]
  1648. audio = paddle.to_tensor(audio)
  1649. window = hann_window(N_FFT)
  1650. stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
  1651. magnitudes = stft[:, :-1].abs() ** 2
  1652. filters = mel_filters(resource_path, n_mels)
  1653. mel_spec = filters @ magnitudes
  1654. mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
  1655. log_spec = paddle.clip(mel_spec, min=1e-10).log10()
  1656. log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
  1657. log_spec = (log_spec + 4.0) / 4.0
  1658. return log_spec