processors.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)

import os
import zlib
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np

from ....utils.deps import function_requires_deps, is_dep_available
from ..common.tokenizer import GPTTokenizer

if is_dep_available("soundfile"):
    import soundfile
if is_dep_available("tqdm"):
    import tqdm
if is_dep_available("paddlepaddle"):
    import paddle

__all__ = [
    "Whisper",
    "Tokenizer",
]


def exact_div(x, y):
    assert x % y == 0
    return x // y


_MODELS = ["large"]

SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
)  # 3000: number of frames in a mel spectrogram input


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}


def compression_ratio(text) -> float:
    return len(text) / len(zlib.compress(text.encode("utf-8")))
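
# Illustrative example (not part of the original source): highly repetitive
# text compresses well under zlib and therefore yields a large ratio, e.g.
#   compression_ratio("ha" * 100)      -> roughly 15, flagged as repetitive
#   compression_ratio("a normal, varied sentence")  -> typically close to 1
# `transcribe()` below treats ratios above `compression_ratio_threshold`
# (default 2.4) as a failed, repetitive decode.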


def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
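
# For example:
#   format_timestamp(3661.5)                   -> "01:01:01.500"
#   format_timestamp(7.25)                     -> "00:07.250"
#   format_timestamp(7.25, decimal_marker=",") -> "00:07,250" (SRT-style)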


@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""

    tokenizer: "GPTTokenizer"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        return self.tokenizer.encode(text, **kwargs)

    def decode(
        self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
    ):
        if len(token_ids) > 1:
            ids_list = []
            for ids in token_ids:
                if paddle.is_tensor(ids):
                    ids = ids.item()
                if ids < len(self.tokenizer):
                    ids_list.append(ids)
            token_ids = ids_list
        elif len(token_ids) == 1:
            token_ids = token_ids[0]
        else:
            raise ValueError(f"token_ids {token_ids} load error.")

        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
        """
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        outputs = [
            s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
        ]
        return "".join(outputs)

    @property
    @lru_cache()
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,
        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {
            self.tokenizer.encode(" -").input_ids[0],
            self.tokenizer.encode(" '").input_ids[0],
        }
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.tokenizer.encode(symbol).input_ids,
                self.tokenizer.encode(" " + symbol).input_ids,
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def _get_single_token_id(self, text) -> int:
        tokens = self.tokenizer.encode(text).input_ids
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]


@lru_cache(maxsize=None)
def build_tokenizer(resource_path: str, name: str = "gpt2"):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(resource_path, "assets", name)
    tokenizer = GPTTokenizer.from_pretrained(path)

    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]
    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    resource_path: str,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name = "gpt2"
        task = None
        language = None

    tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    langs = tuple(LANGUAGES.keys())
    sot_sequence = [sot]
    if language is not None:
        sot_sequence.append(sot + 1 + langs.index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(
        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
    )
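
# Illustrative usage (assumes `resource_path` points at a directory containing
# the tokenizer assets under `assets/multilingual`):
#   tokenizer = get_tokenizer(
#       multilingual=True,
#       resource_path="/path/to/resources",
#       task="transcribe",
#       language="zh",
#   )
#   tokenizer.sot_sequence  # ids of (<|startoftranscript|>, <|zh|>, <|transcribe|>)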


class MultiHeadAttention(paddle.nn.Layer):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
        self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        mask: Optional[paddle.Tensor] = None,
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = (
            paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
            * scale
        )
        k = (
            paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
            * scale
        )
        v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
        return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
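
# Note on the scaling in `qkv_attention` above: standard scaled dot-product
# attention divides q @ k^T by sqrt(d_head). Here the equivalent factor
# d_head ** -0.25 is folded into both q and k before the matmul, so their
# product carries the full 1/sqrt(d_head) scaling while keeping the
# intermediate activations smaller in magnitude.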


class ResidualAttentionBlock(paddle.nn.Layer):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = paddle.nn.LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
            paddle.nn.GELU(),
            paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
        )
        self.mlp_ln = paddle.nn.LayerNorm(n_state)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = paddle.exp(
        -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
    )
    scaled_time = (
        paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
        * inv_timescales[np.newaxis, :]
    )
    return paddle.to_tensor(
        paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
    )
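
# For example, sinusoids(1500, 512) returns a [1500, 512] tensor whose first
# 256 columns are sines and last 256 columns are cosines at geometrically
# spaced timescales; this is the audio encoder's positional embedding below.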


class AudioEncoder(paddle.nn.Layer):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = paddle.nn.Conv1D(
            n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
        )
        self.conv2 = paddle.nn.Conv1D(
            n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = paddle.nn.LayerNorm(n_state)

    def forward(self, x: paddle.Tensor):
        """
        x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = paddle.nn.functional.gelu(self.conv1(x))
        x = paddle.nn.functional.gelu(self.conv2(x))
        x = paddle.transpose(x, (0, 2, 1))

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = x + self.positional_embedding

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(paddle.nn.Layer):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
        self.positional_embedding = paddle.create_parameter(
            shape=[n_ctx, n_state], dtype="float32"
        )

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = paddle.nn.LayerNorm(n_state)

        # causal mask: a square [n_ctx, n_ctx] matrix with -inf above the diagonal
        mask = paddle.full(shape=[n_ctx, n_ctx], fill_value=-np.inf, dtype="float32")
        mask = paddle.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistable=False)

    def forward(
        self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
    ):
        """
        x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : paddle.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
        return logits


@dataclass(frozen=True)
class DecodingOptions:
    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
    language: Optional[str] = None  # language that the audio is in; uses detected language if None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent samples to collect, when t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)

    # options for ranking generations (either beams or best-of-N samples)
    length_penalty: Optional[float] = None  # "alpha" in Google NMT, None defaults to length norm

    # prompt, prefix, and token suppression
    prompt: Optional[Union[str, List[int]]] = None  # text or tokens for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # text or tokens to prefix the current context
    suppress_blank: bool = True  # this will suppress blank outputs

    # list of token ids (or a comma-separated string of token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this

    # implementation details
    fp16: bool = False  # use fp16 for most of the calculation
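
# Example (illustrative): greedy transcription with timestamps, matching the
# defaults used by `transcribe()` below:
#   options = DecodingOptions(task="transcribe", language="en", temperature=0.0)
# or beam search with 5 beams instead of greedy sampling:
#   options = DecodingOptions(language="en", beam_size=5)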


@dataclass(frozen=True)
class DecodingResult:
    audio_features: paddle.Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan


class Inference:
    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""


class WhisperInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        for module, tensor in self.kv_cache.items():
            # update the key/value cache to contain the selected sequences
            self.kv_cache[module] = tensor[source_indices].detach()


@paddle.no_grad()
def detect_language(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    tokenizer: Tokenizer = None,
) -> Tuple[paddle.Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return the ids of the most probable
    language tokens along with the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (batch_size,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = batch_size
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    batch_size = mel.shape[0]
    x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = paddle.argmax(logits, axis=-1)
    language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
    language_probs = [
        {
            c: language_token_probs[i, j].tolist()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(batch_size)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
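
# Illustrative usage (assumes `mel` is a log-mel spectrogram already padded or
# trimmed to N_FRAMES, e.g. via `pad_or_trim(mel, N_FRAMES)` as in `transcribe`):
#   tokens, probs = detect_language(model, mel, resource_path)
#   best = max(probs, key=probs.get)  # e.g. "en"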


@function_requires_deps("tqdm")
def transcribe(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
    """
    Transcribe an audio file using Whisper

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    mel: paddle.Tensor
        The audio feature

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the
        details; if False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively
        used upon failures according to either `compression_ratio_threshold` or
        `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        If True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone
        to getting stuck in a failure loop, such as repetition looping or timestamps going out of
        sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = np.float32  # paddle only supports float32

    if dtype == np.float32:
        decode_options["fp16"] = False

    if (
        decode_options.get("language") == "None"
        or decode_options.get("language", None) is None
    ):
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        resource_path=resource_path,
        language=language,
        task=task,
    )

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None)
    if initial_prompt and initial_prompt != "None":
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    else:
        initial_prompt = []

    def add_segment(
        *,
        start: float,
        end: float,
        text_tokens: paddle.Tensor,
        result: DecodingResult,
    ):
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot]
        )
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin)
            )

            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            if (
                len(consecutive) > 0
            ):  # if the output contains two consecutive timestamp tokens
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset
                        + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
        segments=all_segments,
        language=language,
    )


class SequenceRanker:
    def rank(
        self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(
        self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
    ):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None or self.length_penalty == "None":
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
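
# Worked example of the Google NMT penalty above: with length_penalty=0.6 and a
# 25-token sequence, penalty = ((5 + 25) / 6) ** 0.6 = 5 ** 0.6 ≈ 2.63, so a
# cumulative logprob of -10.0 scores -10.0 / 2.63 ≈ -3.80. With
# length_penalty=None the score is plain length normalization: -10.0 / 25 = -0.4.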


class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (batch_size, beam_size)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = batch_size
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = batch_size
            sequence of cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        temperature = self.temperature
        if temperature == 0:
            next_tokens = paddle.argmax(logits, axis=-1)
        else:
            next_tokens = paddle.distribution.Categorical(
                logits=logits / temperature
            ).sample([1])
            next_tokens = paddle.reshape(
                next_tokens,
                [
                    next_tokens.shape[0] * next_tokens.shape[1],
                ],
            )

        logprobs = paddle.nn.functional.log_softmax(
            logits, axis=-1, dtype=paddle.float32
        )
        current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * paddle.to_tensor(
            (tokens[:, -1] != self.eot), dtype=paddle.float32
        )

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)

        completed = paddle.all((tokens[:, -1] == self.eot))
        return tokens, completed

    def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = paddle.nn.functional.pad(
            tokens, (0, 1), value=self.eot, data_format="NCL"
        )
        return tokens, sum_logprobs.tolist()


class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        if patience is None or patience == "None":
            self.patience = 1.0
        else:
            self.patience = patience
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        batch_size = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(batch_size)]

        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(batch_size):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
                for logprob, token in zip(logprob, token):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = paddle.to_tensor(next_tokens)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[paddle.Tensor]] = [
            [paddle.to_tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
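
# Note on patience: max_candidates = round(beam_size * patience), so with the
# default patience of 1.0 the search stops once `beam_size` finished candidates
# exist per audio; e.g. beam_size=5 with patience=2.0 keeps decoding until 10
# finished sequences exist, trading extra steps for a wider candidate pool.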


class LogitFilter:
    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[
                :, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]
            ] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        logits[:, self.suppress_tokens] = -np.inf


class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            seq = tokens[k, self.sample_begin :].tolist()
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )
            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

        # apply the `max_initial_timestamp` option
        if (
            tokens.shape[1] == self.sample_begin
            and self.max_initial_timestamp_index is not None
        ):
            last_allowed = (
                self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
            )
            logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        for k in range(tokens.shape[0]):
            # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
            # To bypass this issue in CI, we have decomposed the operation into separate steps.
            # It will raise 2e-6 difference in precision.
            # TODO: revert this after logsumexp has been fixed.
            timestamp_logprob = paddle.exp(
                logprobs[k, self.tokenizer.timestamp_begin :]
            )
            timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
            timestamp_logprob = paddle.log(timestamp_logprob)
            max_text_token_logprob = paddle.max(
                logprobs[k, : self.tokenizer.timestamp_begin]
            )
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
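
# Example of the pairing rule above (illustrative): a well-formed decode looks
# like <|0.00|> text ... <|1.08|><|1.08|> more text ... <|2.40|><|eot|>. After
# a lone timestamp, only another timestamp or EOT may follow (text tokens are
# masked); after a completed timestamp pair, only non-timestamp tokens may
# follow (further timestamps are masked).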

class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(
        self, model: "Whisper", options: DecodingOptions, resource_path: str
    ):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            resource_path=resource_path,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
        self.resource_path: str = resource_path

        self.beam_size: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = WhisperInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: apply various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(
                SuppressBlank(self.tokenizer, self.sample_begin)
            )
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )
    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError(
                    "best_of with greedy sampling (T=0) is not compatible"
                )
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and options.length_penalty != "None":
            if not (0 <= options.length_penalty <= 1):
                raise ValueError(
                    "length_penalty (alpha) should be a value between 0 and 1"
                )
        return options
    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()).input_ids
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()).input_ids
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)
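
    # Layout of the initial token sequence assembled above (illustrative; the
    # bracketed names denote special tokens, not literal ids):
    #
    #   [<|startofprev|>] <prompt tokens> <|startoftranscript|> <|lang|> <|task|> <prefix tokens>
    #
    # Everything before `sample_begin` is conditioning context and is sliced
    # off the decoded output; the prompt carries prior-window text, while the
    # prefix constrains how the current transcription begins.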
    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(
                suppress_tokens, list
            ), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
        )

        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))
    def _get_audio_features(self, mel: paddle.Tensor):
        # note: paddle Tensor.shape is a list, so compare against a list here
        if mel.shape[-2:] == [
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ]:
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)
        return audio_features
    def _detect_language(
        self,
        audio_features: paddle.Tensor,
        tokens: paddle.Tensor,
        resource_path: str,
    ):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer, self.resource_path
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs
    def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: paddle.Tensor = paddle.zeros([n_batch], dtype=paddle.float32)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = paddle.nn.functional.softmax(
                        logits[:, self.sot_index], axis=-1, dtype=paddle.float32
                    )
                    no_speech_probs = probs_at_sot[
                        :, self.tokenizer.no_speech
                    ].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(
                    tokens, logits, sum_logprobs
                )

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs
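
    # The filters applied inside the loop above all follow the LogitFilter
    # protocol: mutate `logits` in place given the tokens decoded so far. A
    # minimal hypothetical example (illustrative only, not part of this module):
    #
    #   class SuppressEOT(LogitFilter):
    #       """Forbid ending the sequence before `min_len` sampled tokens."""
    #
    #       def __init__(self, eot: int, sample_begin: int, min_len: int):
    #           self.eot = eot
    #           self.sample_begin = sample_begin
    #           self.min_len = min_len
    #
    #       def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
    #           if tokens.shape[1] - self.sample_begin < self.min_len:
    #               logits[:, self.eot] = -np.inf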
    @paddle.no_grad()
    def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        batch_size: int = mel.shape[0]

        audio_features: paddle.Tensor = self._get_audio_features(
            mel
        )  # encoder forward pass

        # repeat the initial tokens once per item in the batch
        tokens: paddle.Tensor = paddle.to_tensor([self.initial_tokens]).tile(
            [batch_size, 1]
        )

        # detect language if requested, overwriting the language token in-place
        languages, language_probs = self._detect_language(
            audio_features, tokens, self.resource_path
        )
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = paddle.repeat_interleave(
            audio_features, self.beam_size, axis=0
        )
        tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(
            audio_features, tokens
        )

        # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
        audio_features = audio_features[:: self.beam_size]
        no_speech_probs = no_speech_probs[:: self.beam_size]
        assert audio_features.shape[0] == len(no_speech_probs) == batch_size

        tokens = tokens.reshape([batch_size, self.beam_size, -1])
        sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[paddle.Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(
                f"inconsistent result lengths: {list(map(len, fields))}"
            )

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]

@paddle.no_grad()
def decode(
    model: "Whisper",
    mel: paddle.Tensor,
    options: DecodingOptions = DecodingOptions(),
    resource_path: str = None,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    resource_path: str
        The path to the tokenizer and filterbank assets

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options, resource_path).run(mel)

    if single:
        result = result[0]

    return result
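
# Example usage (a hedged sketch; `model`, `waveform`, and `resource_path` are
# assumed to be provided by the caller, e.g. a checkpoint loader elsewhere):
#
#   audio = pad_or_trim(waveform)                     # 30 s of 16 kHz samples
#   mel = log_mel_spectrogram(audio, resource_path=resource_path)
#   result = decode(model, mel, DecodingOptions(language="en"), resource_path)
#   print(result.text, result.avg_logprob)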

class Whisper(paddle.nn.Layer):
    """
    The `Whisper` module wires an AudioEncoder to a TextDecoder and exposes the
    `detect_language`, `transcribe`, and `decode` entry points.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: paddle.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
        return self.decoder.forward(tokens, audio_features)

    def forward(
        self, mel: paddle.Tensor, tokens: paddle.Tensor
    ) -> Dict[str, paddle.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return paddle.device.get_device()

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache`, which stores the key and
        value tensors calculated for the previous positions. This method returns a dictionary
        that stores all caches, and the necessary hooks on the key and value projection modules
        that save the intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[paddle.nn.Layer, paddle.Tensor]
            A dictionary object mapping each key/value projection module to its cache
        hooks : List[RemovableHandle]
            A list of removable handles for deregistering the hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if (
                module not in cache
                or output.shape[1] > self.decoder.positional_embedding.shape[0]
            ):
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = paddle.concat(
                    [cache[module], output], axis=1
                ).detach()
            return cache[module]

        def install_hooks(layer: paddle.nn.Layer):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_post_hook(save_to_cache))
                hooks.append(layer.value.register_forward_post_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language
    transcribe = transcribe
    decode = decode
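
# Sketch of the kv-cache lifecycle provided by `install_kv_cache_hooks`
# (hedged; `WhisperInference` elsewhere in this module manages the cache in
# essentially this way):
#
#   cache, hooks = model.install_kv_cache_hooks()
#   logits = model.decoder(tokens, audio_features, kv_cache=cache)
#   ...                                  # subsequent steps reuse the cached k/v
#   for hook in hooks:
#       hook.remove()                    # deregister hooks when decoding ends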

def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if paddle.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(axis=axis, index=paddle.arange(length))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # paddle.nn.functional.pad expects a channels-last layout here,
            # so transpose around the padding op
            array = paddle.transpose(array, (1, 0))
            array = paddle.nn.functional.pad(
                array,
                [pad for sizes in pad_widths[::-1] for pad in sizes],
                data_format="NLC",
            )
            array = paddle.transpose(array, (1, 0))
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array
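
# Example (hedged; shapes are illustrative):
#
#   wav = np.zeros(100_000, dtype="float32")   # ~6.25 s at 16 kHz
#   wav = pad_or_trim(wav)                     # zero-padded to N_SAMPLES samples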

def hann_window(n_fft: int = N_FFT):
    """
    Hann window.

    n_fft: the number of frequency components of the discrete Fourier transform.
    """
    return paddle.to_tensor(
        [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
        dtype=paddle.float32,
    )
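
# The values above implement the periodic Hann window,
# w[n] = 0.5 - 0.5 * cos(2 * pi * n / n_fft) for n in [0, n_fft),
# the standard analysis window for STFT.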

@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
    """
    Load the mel filterbank matrix for projecting the STFT into a Mel spectrogram.
    This decouples the librosa dependency; the matrix was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
        return paddle.to_tensor(f[f"mel_{n_mels}"])

@function_requires_deps("soundfile")
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, paddle.Tensor],
    n_mels: int = N_MELS,
    resource_path: str = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
        The path to an audio file, or a NumPy array / Tensor containing the
        audio waveform at 16 kHz

    n_mels: int
        The number of Mel-frequency filters; only 80 is supported

    resource_path: str
        The path containing the `assets/mel_filters.npz` filterbank

    Returns
    -------
    paddle.Tensor, shape = (80, n_frames)
        A Tensor that contains the log-Mel spectrogram
    """
    if not paddle.is_tensor(audio):
        if isinstance(audio, str):
            audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
            audio = audio[:, 0]
        audio = paddle.to_tensor(audio)

    window = hann_window(N_FFT)
    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
    magnitudes = stft[:, :-1].abs() ** 2

    filters = mel_filters(resource_path, n_mels)
    mel_spec = filters @ magnitudes
    # NOTE: this NumPy round-trip is a workaround kept from the original code
    mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())

    # floor at 1e-10 before log10, clip to 8 log10-units (80 dB) below the
    # peak, then shift/scale into a roughly unit range
    log_spec = paddle.clip(mel_spec, min=1e-10).log10()
    log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
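
# Shape walk-through for a full 30-second segment (hedged, assuming the usual
# Whisper constants N_FFT = 400 and HOP_LENGTH = 160 defined above):
#
#   audio:      (480000,)      # N_SAMPLES at 16 kHz
#   stft:       (201, 3001)    # N_FFT // 2 + 1 frequency bins
#   magnitudes: (201, 3000)    # the last frame is dropped above
#   mel_spec:   (80, 3000)     # filters (80, 201) @ magnitudes
#   log_spec:   (80, 3000)     # the (80, 3000) input expected by `decode`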