processors.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
  15. import os
  16. import tqdm
  17. import zlib
  18. import soundfile
  19. import numpy as np
  20. import lazy_paddle as paddle
  21. from dataclasses import dataclass
  22. from dataclasses import field
  23. from functools import lru_cache
  24. from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
  25. from ..common.tokenizer import GPTTokenizer
# Names exported by a wildcard import (`from <module> import *`) of this module.
__all__ = [
    "Whisper",
    "Tokenizer",
]
  30. def exact_div(x, y):
  31. assert x % y == 0
  32. return x // y
# NOTE(review): presumably the supported model sizes; no use is visible in this part of the file.
_MODELS = ["large"]

# Audio front-end hyper-parameters. The derived values below follow directly
# from the arithmetic: N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE and
# N_FRAMES = N_SAMPLES / HOP_LENGTH (which must divide exactly).
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
)  # 3000: number of frames in a mel spectrogram input
@dataclass
class ModelDimensions:
    """Plain container for the integer size hyper-parameters of a model.

    NOTE(review): the field names mirror the constructor parameters of the
    audio encoder and text decoder (n_mels / n_ctx / n_state / n_head /
    n_layer / n_vocab); the exact wiring is not visible here - confirm
    against the model class that consumes this dataclass.
    """

    # audio-side sizes
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    # text-side sizes
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
# Language code -> lower-case English language name.
# The insertion order matters: other code in this file derives language token
# ids from the position of a code in this mapping, so do not reorder entries.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
# Language code lookup by name, with a few language aliases.
# The first entry inverts LANGUAGES (name -> code); the explicit entries that
# follow add alternative names that map onto an existing code.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
  171. def compression_ratio(text) -> float:
  172. return len(text) / len(zlib.compress(text.encode("utf-8")))
  173. def format_timestamp(
  174. seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
  175. ):
  176. assert seconds >= 0, "non-negative timestamp expected"
  177. milliseconds = round(seconds * 1000.0)
  178. hours = milliseconds // 3_600_000
  179. milliseconds -= hours * 3_600_000
  180. minutes = milliseconds // 60_000
  181. milliseconds -= minutes * 60_000
  182. seconds = milliseconds // 1_000
  183. milliseconds -= seconds * 1_000
  184. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  185. return (
  186. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  187. )
  188. @dataclass(frozen=True)
  189. class Tokenizer:
  190. """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
  191. tokenizer: "GPTTokenizer"
  192. language: Optional[str]
  193. sot_sequence: Tuple[int]
  194. def encode(self, text, **kwargs):
  195. return self.tokenizer.encode(text, **kwargs)
  196. def decode(
  197. self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
  198. ):
  199. if len(token_ids) > 1:
  200. ids_list = []
  201. for ids in token_ids:
  202. if paddle.is_tensor(ids):
  203. ids = ids.item()
  204. if ids < len(self.tokenizer):
  205. ids_list.append(ids)
  206. token_ids = ids_list
  207. elif len(token_ids) == 1:
  208. token_ids = token_ids[0]
  209. else:
  210. raise ValueError(f"token_ids {token_ids} load error.")
  211. return self.tokenizer.decode(token_ids, **kwargs)
  212. def decode_with_timestamps(self, tokens) -> str:
  213. """
  214. Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
  215. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
  216. """
  217. outputs = [[]]
  218. for token in tokens:
  219. if token >= self.timestamp_begin:
  220. timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
  221. outputs.append(timestamp)
  222. outputs.append([])
  223. else:
  224. outputs[-1].append(token)
  225. outputs = [
  226. s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
  227. ]
  228. return "".join(outputs)
  229. @property
  230. @lru_cache()
  231. def eot(self) -> int:
  232. return self.tokenizer.eos_token_id
  233. @property
  234. @lru_cache()
  235. def sot(self) -> int:
  236. return self._get_single_token_id("<|startoftranscript|>")
  237. @property
  238. @lru_cache()
  239. def sot_lm(self) -> int:
  240. return self._get_single_token_id("<|startoflm|>")
  241. @property
  242. @lru_cache()
  243. def sot_prev(self) -> int:
  244. return self._get_single_token_id("<|startofprev|>")
  245. @property
  246. @lru_cache()
  247. def no_speech(self) -> int:
  248. return self._get_single_token_id("<|nospeech|>")
  249. @property
  250. @lru_cache()
  251. def no_timestamps(self) -> int:
  252. return self._get_single_token_id("<|notimestamps|>")
  253. @property
  254. @lru_cache()
  255. def timestamp_begin(self) -> int:
  256. return self.tokenizer.all_special_ids[-1] + 1
  257. @property
  258. @lru_cache()
  259. def language_token(self) -> int:
  260. """Returns the token id corresponding to the value of the `language` field"""
  261. if self.language is None:
  262. raise ValueError("This tokenizer does not have language token configured")
  263. additional_tokens = dict(
  264. zip(
  265. self.tokenizer.additional_special_tokens,
  266. self.tokenizer.additional_special_tokens_ids,
  267. )
  268. )
  269. candidate = f"<|{self.language}|>"
  270. if candidate in additional_tokens:
  271. return additional_tokens[candidate]
  272. raise KeyError(f"Language {self.language} not found in tokenizer.")
  273. @property
  274. @lru_cache()
  275. def all_language_tokens(self) -> Tuple[int]:
  276. result = []
  277. for token, token_id in zip(
  278. self.tokenizer.additional_special_tokens,
  279. self.tokenizer.additional_special_tokens_ids,
  280. ):
  281. if token.strip("<|>") in LANGUAGES:
  282. result.append(token_id)
  283. return tuple(result)
  284. @property
  285. @lru_cache()
  286. def all_language_codes(self) -> Tuple[str]:
  287. return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
  288. @property
  289. @lru_cache()
  290. def sot_sequence_including_notimestamps(self) -> Tuple[int]:
  291. return tuple(list(self.sot_sequence) + [self.no_timestamps])
  292. @property
  293. @lru_cache()
  294. def non_speech_tokens(self) -> Tuple[int]:
  295. """
  296. Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
  297. annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
  298. - ♪♪♪
  299. - ( SPEAKING FOREIGN LANGUAGE )
  300. - [DAVID] Hey there,
  301. keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
  302. """
  303. symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
  304. symbols += (
  305. "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
  306. )
  307. # symbols that may be a single token or multiple tokens depending on the tokenizer.
  308. # In case they're multiple tokens, suppress the first token, which is safe because:
  309. # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
  310. # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
  311. miscellaneous = set("♩♪♫♬♭♮♯")
  312. assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
  313. # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
  314. result = {
  315. self.tokenizer.encode(" -").input_ids[0],
  316. self.tokenizer.encode(" '").input_ids[0],
  317. }
  318. for symbol in symbols + list(miscellaneous):
  319. for tokens in [
  320. self.tokenizer.encode(symbol).input_ids,
  321. self.tokenizer.encode(" " + symbol).input_ids,
  322. ]:
  323. if len(tokens) == 1 or symbol in miscellaneous:
  324. result.add(tokens[0])
  325. return tuple(sorted(result))
  326. def _get_single_token_id(self, text) -> int:
  327. tokens = self.tokenizer.encode(text).input_ids
  328. assert len(tokens) == 1, f"{text} is not encoded as a single token"
  329. return tokens[0]
  330. @lru_cache(maxsize=None)
  331. def build_tokenizer(resource_path: str, name: str = "gpt2"):
  332. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  333. path = os.path.join(resource_path, "assets", name)
  334. tokenizer = GPTTokenizer.from_pretrained(path)
  335. specials = [
  336. "<|startoftranscript|>",
  337. *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
  338. "<|translate|>",
  339. "<|transcribe|>",
  340. "<|startoflm|>",
  341. "<|startofprev|>",
  342. "<|nospeech|>",
  343. "<|notimestamps|>",
  344. ]
  345. tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
  346. return tokenizer
  347. @lru_cache(maxsize=None)
  348. def get_tokenizer(
  349. multilingual: bool,
  350. resource_path: str,
  351. *,
  352. task: Optional[str] = None, # Literal["transcribe", "translate", None]
  353. language: Optional[str] = None,
  354. ) -> Tokenizer:
  355. if language is not None:
  356. language = language.lower()
  357. if language not in LANGUAGES:
  358. if language in TO_LANGUAGE_CODE:
  359. language = TO_LANGUAGE_CODE[language]
  360. else:
  361. raise ValueError(f"Unsupported language: {language}")
  362. if multilingual:
  363. tokenizer_name = "multilingual"
  364. task = task or "transcribe"
  365. language = language or "en"
  366. else:
  367. tokenizer_name = "gpt2"
  368. task = None
  369. language = None
  370. tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
  371. all_special_ids: List[int] = tokenizer.all_special_ids
  372. sot: int = all_special_ids[1]
  373. translate: int = all_special_ids[-6]
  374. transcribe: int = all_special_ids[-5]
  375. langs = tuple(LANGUAGES.keys())
  376. sot_sequence = [sot]
  377. if language is not None:
  378. sot_sequence.append(sot + 1 + langs.index(language))
  379. if task is not None:
  380. sot_sequence.append(transcribe if task == "transcribe" else translate)
  381. return Tokenizer(
  382. tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
  383. )
  384. class MultiHeadAttention(paddle.nn.Layer):
  385. def __init__(self, n_state: int, n_head: int):
  386. super().__init__()
  387. self.n_head = n_head
  388. self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  389. self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
  390. self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  391. self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
  392. def forward(
  393. self,
  394. x: paddle.Tensor,
  395. xa: Optional[paddle.Tensor] = None,
  396. mask: Optional[paddle.Tensor] = None,
  397. kv_cache: Optional[dict] = None,
  398. ):
  399. q = self.query(x)
  400. if kv_cache is None or xa is None or self.key not in kv_cache:
  401. # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
  402. # otherwise, perform key/value projections for self- or cross-attention as usual.
  403. k = self.key(x if xa is None else xa)
  404. v = self.value(x if xa is None else xa)
  405. else:
  406. # for cross-attention, calculate keys and values once and reuse in subsequent calls.
  407. k = kv_cache[self.key]
  408. v = kv_cache[self.value]
  409. wv = self.qkv_attention(q, k, v, mask)
  410. return self.out(wv)
  411. def qkv_attention(
  412. self,
  413. q: paddle.Tensor,
  414. k: paddle.Tensor,
  415. v: paddle.Tensor,
  416. mask: Optional[paddle.Tensor] = None,
  417. ):
  418. n_batch, n_ctx, n_state = q.shape
  419. scale = (n_state // self.n_head) ** -0.25
  420. q = (
  421. paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  422. * scale
  423. )
  424. k = (
  425. paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
  426. * scale
  427. )
  428. v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
  429. qk = q @ k
  430. if mask is not None:
  431. qk = qk + mask[:n_ctx, :n_ctx]
  432. w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
  433. return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
  434. class ResidualAttentionBlock(paddle.nn.Layer):
  435. def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
  436. super().__init__()
  437. self.attn = MultiHeadAttention(n_state, n_head)
  438. self.attn_ln = paddle.nn.LayerNorm(n_state)
  439. self.cross_attn = (
  440. MultiHeadAttention(n_state, n_head) if cross_attention else None
  441. )
  442. self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None
  443. n_mlp = n_state * 4
  444. self.mlp = paddle.nn.Sequential(
  445. paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
  446. paddle.nn.GELU(),
  447. paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
  448. )
  449. self.mlp_ln = paddle.nn.LayerNorm(n_state)
  450. def forward(
  451. self,
  452. x: paddle.Tensor,
  453. xa: Optional[paddle.Tensor] = None,
  454. mask: Optional[paddle.Tensor] = None,
  455. kv_cache: Optional[dict] = None,
  456. ):
  457. x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
  458. if self.cross_attn:
  459. x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
  460. x = x + self.mlp(self.mlp_ln(x))
  461. return x
  462. def sinusoids(length, channels, max_timescale=10000):
  463. """Returns sinusoids for positional embedding"""
  464. assert channels % 2 == 0
  465. log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
  466. inv_timescales = paddle.exp(
  467. -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
  468. )
  469. scaled_time = (
  470. paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
  471. * inv_timescales[np.newaxis, :]
  472. )
  473. return paddle.to_tensor(
  474. paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
  475. )
  476. class AudioEncoder(paddle.nn.Layer):
  477. def __init__(
  478. self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  479. ):
  480. super().__init__()
  481. self.conv1 = paddle.nn.Conv1D(
  482. n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
  483. )
  484. self.conv2 = paddle.nn.Conv1D(
  485. n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
  486. )
  487. self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
  488. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  489. [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
  490. )
  491. self.ln_post = paddle.nn.LayerNorm(n_state)
  492. def forward(self, x: paddle.Tensor):
  493. """
  494. x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
  495. the mel spectrogram of the audio
  496. """
  497. x = paddle.nn.functional.gelu(self.conv1(x))
  498. x = paddle.nn.functional.gelu(self.conv2(x))
  499. x = paddle.transpose(x, (0, 2, 1))
  500. assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
  501. x = x + self.positional_embedding
  502. for block in self.blocks:
  503. x = block(x)
  504. x = self.ln_post(x)
  505. return x
  506. class TextDecoder(paddle.nn.Layer):
  507. def __init__(
  508. self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
  509. ):
  510. super().__init__()
  511. self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
  512. self.positional_embedding = paddle.create_parameter(
  513. shape=[n_ctx, n_state], dtype="float32"
  514. )
  515. self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
  516. [
  517. ResidualAttentionBlock(n_state, n_head, cross_attention=True)
  518. for _ in range(n_layer)
  519. ]
  520. )
  521. self.ln = paddle.nn.LayerNorm(n_state)
  522. mask = paddle.full(shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32")
  523. mask = paddle.triu(mask, diagonal=1)
  524. self.register_buffer("mask", mask, persistable=False)
  525. def forward(
  526. self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
  527. ):
  528. """
  529. x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
  530. the text tokens
  531. xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
  532. the encoded audio features to be attended on
  533. """
  534. offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
  535. x = (
  536. self.token_embedding(x)
  537. + self.positional_embedding[offset : offset + x.shape[-1]]
  538. )
  539. x = x.to(xa.dtype)
  540. for block in self.blocks:
  541. x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
  542. x = self.ln(x)
  543. logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
  544. return logits
  545. @dataclass(frozen=True)
  546. class DecodingOptions:
  547. task: str = (
  548. "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
  549. )
  550. language: Optional[str] = (
  551. None # language that the audio is in; uses detected language if None
  552. )
  553. # sampling-related options
  554. temperature: float = 0.0
  555. sample_len: Optional[int] = None # maximum number of tokens to sample
  556. best_of: Optional[int] = (
  557. None # number of independent samples to collect, when t > 0
  558. )
  559. beam_size: Optional[int] = None # number of beams in beam search, when t == 0
  560. patience: Optional[float] = (
  561. None # patience in beam search (https://arxiv.org/abs/2204.05424)
  562. )
  563. # options for ranking generations (either beams or best-of-N samples)
  564. length_penalty: Optional[float] = (
  565. None # "alpha" in Google NMT, None defaults to length norm
  566. )
  567. # prompt, prefix, and token suppression
  568. prompt: Optional[Union[str, List[int]]] = (
  569. None # text or tokens for the previous context
  570. )
  571. prefix: Optional[Union[str, List[int]]] = (
  572. None # text or tokens to prefix the current context
  573. )
  574. suppress_blank: bool = True # this will suppress blank outputs
  575. # list of tokens ids (or comma-separated token ids) to suppress
  576. # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
  577. suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
  578. # timestamp sampling options
  579. without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
  580. max_initial_timestamp: Optional[float] = (
  581. 1.0 # the initial timestamp cannot be later than this
  582. )
  583. # implementation details
  584. fp16: bool = False # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record describing the outcome of decoding one audio input.

    Only `audio_features` and `language` are required; every other field has
    a default (NaN for the numeric scores, empty for tokens/text).
    """

    audio_features: paddle.Tensor
    language: str
    # NOTE(review): presumably a language-code -> probability mapping as produced by
    # language detection; confirm against the code that fills this in.
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
  596. class Inference:
  597. def logits(
  598. self, tokens: paddle.Tensor, audio_features: paddle.Tensor
  599. ) -> paddle.Tensor:
  600. """Perform a forward pass on the decoder and return per-token logits"""
  601. raise NotImplementedError
  602. def rearrange_kv_cache(self, source_indices) -> None:
  603. """Update the key-value cache according to the updated beams"""
  604. raise NotImplementedError
  605. def cleanup_caching(self) -> None:
  606. """Clean up any resources or hooks after decoding is finished"""
  607. pass
  608. class WhisperInference(Inference):
  609. def __init__(self, model: "Whisper", initial_token_length: int):
  610. self.model: "Whisper" = model
  611. self.initial_token_length = initial_token_length
  612. self.kv_cache = {}
  613. self.hooks = []
  614. def logits(
  615. self, tokens: paddle.Tensor, audio_features: paddle.Tensor
  616. ) -> paddle.Tensor:
  617. if not self.kv_cache:
  618. self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
  619. if tokens.shape[-1] > self.initial_token_length:
  620. # only need to use the last token except in the first forward pass
  621. tokens = tokens[:, -1:]
  622. return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  623. def cleanup_caching(self):
  624. for hook in self.hooks:
  625. hook.remove()
  626. self.kv_cache = {}
  627. self.hooks = []
  628. def rearrange_kv_cache(self, source_indices):
  629. for module, tensor in self.kv_cache.items():
  630. # update the key/value cache to contain the selected sequences
  631. self.kv_cache[module] = tensor[source_indices].detach()
  632. @paddle.no_grad()
  633. def detect_language(
  634. model: "Whisper",
  635. mel: paddle.Tensor,
  636. resource_path: str,
  637. tokenizer: Tokenizer = None,
  638. ) -> Tuple[paddle.Tensor, List[dict]]:
  639. """
  640. Detect the spoken language in the audio, and return them as list of strings, along with the ids
  641. of the most probable language tokens and the probability distribution over all language tokens.
  642. This is performed outside the main decode loop in order to not interfere with kv-caching.
  643. Returns
  644. -------
  645. language_tokens : Tensor, shape = (batch_size,)
  646. ids of the most probable language tokens, which appears after the startoftranscript token.
  647. language_probs : List[Dict[str, float]], length = batch_size
  648. list of dictionaries containing the probability distribution over all languages.
  649. """
  650. if tokenizer is None:
  651. tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
  652. if (
  653. tokenizer.language is None
  654. or tokenizer.language_token not in tokenizer.sot_sequence
  655. ):
  656. raise ValueError(
  657. "This model doesn't have language tokens so it can't perform lang id"
  658. )
  659. single = mel.ndim == 2
  660. if single:
  661. mel = mel.unsqueeze(0)
  662. # skip encoder forward pass if already-encoded audio features were given
  663. if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
  664. mel = model.encoder(mel)
  665. # forward pass using a single token, startoftranscript
  666. batch_size = mel.shape[0]
  667. x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
  668. logits = model.logits(x, mel)[:, 0]
  669. # collect detected languages; suppress all non-language tokens
  670. mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
  671. mask[list(tokenizer.all_language_tokens)] = False
  672. logits[:, mask] = -np.inf
  673. language_tokens = paddle.argmax(logits, axis=-1)
  674. language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
  675. language_probs = [
  676. {
  677. c: language_token_probs[i, j].tolist()
  678. for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
  679. }
  680. for i in range(batch_size)
  681. ]
  682. if single:
  683. language_tokens = language_tokens[0]
  684. language_probs = language_probs[0]
  685. return language_tokens, language_probs
def transcribe(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
    """
    Transcribe an audio file using Whisper

    The audio features are decoded window by window; after each window the seek
    position advances either to the last timestamp predicted by the model or by a
    full window.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    mel: paddle.Tensor
        The audio feature; the last axis is treated as the frame axis.

    resource_path: str
        Forwarded to `get_tokenizer`, `model.detect_language` and `model.decode`.

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances. `initial_prompt` is
        consumed here and is not forwarded to `DecodingOptions`.

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = np.float32  # paddle only support float32
    if dtype == np.float32:
        decode_options["fp16"] = False

    # The literal string "None" is treated like a missing language.
    if (
        decode_options.get("language") == "None"
        or decode_options.get("language", None) is None
    ):
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            # pick the language code with the highest probability
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual, resource_path=resource_path, language=language, task=task
    )

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        # Decode one window, trying each temperature in turn until a result passes
        # both quality checks; the result of the last attempt is returned either way.
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0  # position, in mel frames, of the start of the current window
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    # index into all_tokens from which the prompt for the next window is taken
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None)
    if initial_prompt and initial_prompt != "None":
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    else:
        initial_prompt = []

    def add_segment(
        *, start: float, end: float, text_tokens: paddle.Tensor, result: DecodingResult
    ):
        # Record one segment in all_segments; only tokens below `eot` are decoded to text.
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot]
        )
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            # tokens accumulated since the last prompt reset become the prompt
            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    # NOTE: the progress bar is not updated on this path; the skipped
                    # frames are counted by the next update that does run.
                    seek += segment.shape[
                        -1
                    ]  # fast-forward to the next segment boundary
                    continue

            # boolean mask of the positions that hold timestamp tokens
            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin)
            )
            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            if (
                len(consecutive) > 0
            ):  # if the output contains two consecutive timestamp tokens
                # shift by one so each index points at the second timestamp of a pair
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset
                        + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                # advance only up to the last timestamp that closed a segment
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    # the initial prompt tokens are excluded from the returned text
    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
        segments=all_segments,
        language=language,
    )
  916. class SequenceRanker:
  917. def rank(
  918. self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
  919. ) -> List[int]:
  920. """
  921. Given a list of groups of samples and their cumulative log probabilities,
  922. return the indices of the samples in each group to select as the final result
  923. """
  924. raise NotImplementedError
  925. class MaximumLikelihoodRanker(SequenceRanker):
  926. """
  927. Select the sample with the highest log probabilities, penalized using either
  928. a simple length normalization or Google NMT paper's length penalty
  929. """
  930. def __init__(self, length_penalty: Optional[float]):
  931. self.length_penalty = length_penalty
  932. def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
  933. def scores(logprobs, lengths):
  934. result = []
  935. for logprob, length in zip(logprobs, lengths):
  936. if self.length_penalty is None or self.length_penalty == "None":
  937. penalty = length
  938. else:
  939. # from the Google NMT paper
  940. penalty = ((5 + length) / 6) ** self.length_penalty
  941. result.append(logprob / penalty)
  942. return result
  943. # get the sequence with the highest score
  944. lengths = [[len(t) for t in s] for s in tokens]
  945. return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
  946. class TokenDecoder:
  947. def reset(self):
  948. """Initialize any stateful variables for decoding a new sequence"""
  949. def update(
  950. self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
  951. ) -> Tuple[paddle.Tensor, bool]:
  952. """Specify how to select the next token, based on the current trace and logits
  953. Parameters
  954. ----------
  955. tokens : Tensor, shape = (n_batch, current_sequence_length)
  956. all tokens in the context so far, including the prefix and sot_sequence tokens
  957. logits : Tensor, shape = (n_batch, vocab_size)
  958. per-token logits of the probability distribution at the current step
  959. sum_logprobs : Tensor, shape = (n_batch)
  960. cumulative log probabilities for each sequence
  961. Returns
  962. -------
  963. tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
  964. the tokens, appended with the selected next token
  965. completed : bool
  966. True if all sequences has reached the end of text
  967. """
  968. raise NotImplementedError
  969. def finalize(
  970. self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
  971. ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
  972. """Finalize search and return the final candidate sequences
  973. Parameters
  974. ----------
  975. tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
  976. all tokens in the context so far, including the prefix and sot_sequence
  977. sum_logprobs : Tensor, shape = (batch_size, beam_size)
  978. cumulative log probabilities for each sequence
  979. Returns
  980. -------
  981. tokens : Sequence[Sequence[Tensor]], length = batch_size
  982. sequence of Tensors containing candidate token sequences, for each audio input
  983. sum_logprobs : List[List[float]], length = batch_size
  984. sequence of cumulative log probabilities corresponding to the above
  985. """
  986. raise NotImplementedError
  987. class GreedyDecoder(TokenDecoder):
  988. def __init__(self, temperature: float, eot: int):
  989. self.temperature = temperature
  990. self.eot = eot
  991. def update(
  992. self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
  993. ) -> Tuple[paddle.Tensor, bool]:
  994. temperature = self.temperature
  995. if temperature == 0:
  996. next_tokens = paddle.argmax(logits, axis=-1)
  997. else:
  998. next_tokens = paddle.distribution.Categorical(
  999. logits=logits / temperature
  1000. ).sample([1])
  1001. next_tokens = paddle.reshape(
  1002. next_tokens,
  1003. [
  1004. next_tokens.shape[0] * next_tokens.shape[1],
  1005. ],
  1006. )
  1007. logprobs = paddle.nn.functional.log_softmax(
  1008. logits, axis=-1, dtype=paddle.float32
  1009. )
  1010. current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
  1011. sum_logprobs += current_logprobs * paddle.to_tensor(
  1012. (tokens[:, -1] != self.eot), dtype=paddle.float32
  1013. )
  1014. next_tokens[tokens[:, -1] == self.eot] = self.eot
  1015. tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
  1016. completed = paddle.all((tokens[:, -1] == self.eot))
  1017. return tokens, completed
  1018. def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1019. # make sure each sequence has at least one EOT token at the end
  1020. tokens = paddle.nn.functional.pad(
  1021. tokens, (0, 1), value=self.eot, data_format="NCL"
  1022. )
  1023. return tokens, sum_logprobs.tolist()
  1024. class BeamSearchDecoder(TokenDecoder):
  1025. def __init__(
  1026. self,
  1027. beam_size: int,
  1028. eot: int,
  1029. inference: Inference,
  1030. patience: Optional[float] = None,
  1031. ):
  1032. self.beam_size = beam_size
  1033. self.eot = eot
  1034. self.inference = inference
  1035. self.patience = patience or 1.0
  1036. if patience is None or patience == "None":
  1037. self.patience = 1.0
  1038. else:
  1039. self.patience = patience
  1040. self.max_candidates: int = round(beam_size * self.patience)
  1041. self.finished_sequences = None
  1042. assert (
  1043. self.max_candidates > 0
  1044. ), f"Invalid beam size ({beam_size}) or patience ({patience})"
  1045. def reset(self):
  1046. self.finished_sequences = None
  1047. def update(
  1048. self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
  1049. ) -> Tuple[paddle.Tensor, bool]:
  1050. if tokens.shape[0] % self.beam_size != 0:
  1051. raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
  1052. batch_size = tokens.shape[0] // self.beam_size
  1053. if self.finished_sequences is None: # for the first update
  1054. self.finished_sequences = [{} for _ in range(batch_size)]
  1055. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1056. next_tokens, source_indices, finished_sequences = [], [], []
  1057. for i in range(batch_size):
  1058. scores, sources, finished = {}, {}, {}
  1059. # STEP 1: calculate the cumulative log probabilities for possible candidates
  1060. for j in range(self.beam_size):
  1061. idx = i * self.beam_size + j
  1062. prefix = tokens[idx].tolist()
  1063. logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
  1064. for logprob, token in zip(logprob, token):
  1065. new_logprob = (sum_logprobs[idx] + logprob).item()
  1066. sequence = tuple(prefix + [token.item()])
  1067. scores[sequence] = new_logprob
  1068. sources[sequence] = idx
  1069. # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
  1070. saved = 0
  1071. for sequence in sorted(scores, key=scores.get, reverse=True):
  1072. if sequence[-1] == self.eot:
  1073. finished[sequence] = scores[sequence]
  1074. else:
  1075. sum_logprobs[len(next_tokens)] = scores[sequence]
  1076. next_tokens.append(sequence)
  1077. source_indices.append(sources[sequence])
  1078. saved += 1
  1079. if saved == self.beam_size:
  1080. break
  1081. finished_sequences.append(finished)
  1082. tokens = paddle.to_tensor(next_tokens)
  1083. self.inference.rearrange_kv_cache(source_indices)
  1084. # add newly finished sequences to self.finished_sequences
  1085. assert len(self.finished_sequences) == len(finished_sequences)
  1086. for previously_finished, newly_finished in zip(
  1087. self.finished_sequences, finished_sequences
  1088. ):
  1089. for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
  1090. if len(previously_finished) >= self.max_candidates:
  1091. break # the candidate list is full
  1092. previously_finished[seq] = newly_finished[seq]
  1093. # mark as completed if all audio has enough number of samples
  1094. completed = all(
  1095. len(sequences) >= self.max_candidates
  1096. for sequences in self.finished_sequences
  1097. )
  1098. return tokens, completed
  1099. def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
  1100. # collect all finished sequences, including patience, and add unfinished ones if not enough
  1101. sum_logprobs = sum_logprobs.cpu()
  1102. for i, sequences in enumerate(self.finished_sequences):
  1103. if (
  1104. len(sequences) < self.beam_size
  1105. ): # when not enough sequences are finished
  1106. for j in list(np.argsort(sum_logprobs[i]))[::-1]:
  1107. sequence = preceding_tokens[i, j].tolist() + [self.eot]
  1108. sequences[tuple(sequence)] = sum_logprobs[i][j].item()
  1109. if len(sequences) >= self.beam_size:
  1110. break
  1111. tokens: List[List[paddle.Tensor]] = [
  1112. [paddle.to_tensor(seq) for seq in sequences.keys()]
  1113. for sequences in self.finished_sequences
  1114. ]
  1115. sum_logprobs: List[List[float]] = [
  1116. list(sequences.values()) for sequences in self.finished_sequences
  1117. ]
  1118. return tokens, sum_logprobs
  1119. class LogitFilter:
  1120. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
  1121. """Apply any filtering or masking to logits in-place
  1122. Parameters
  1123. ----------
  1124. logits : Tensor, shape = (n_batch, vocab_size)
  1125. per-token logits of the probability distribution at the current step
  1126. tokens : Tensor, shape = (n_batch, current_sequence_length)
  1127. all tokens in the context so far, including the prefix and sot_sequence tokens
  1128. """
  1129. raise NotImplementedError
  1130. class SuppressBlank(LogitFilter):
  1131. def __init__(self, tokenizer: Tokenizer, sample_begin: int):
  1132. self.tokenizer = tokenizer
  1133. self.sample_begin = sample_begin
  1134. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1135. if tokens.shape[1] == self.sample_begin:
  1136. logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
  1137. -np.inf
  1138. )
  1139. class SuppressTokens(LogitFilter):
  1140. def __init__(self, suppress_tokens: Sequence[int]):
  1141. self.suppress_tokens = list(suppress_tokens)
  1142. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1143. logits[:, self.suppress_tokens] = -np.inf
  1144. class ApplyTimestampRules(LogitFilter):
  1145. def __init__(
  1146. self,
  1147. tokenizer: Tokenizer,
  1148. sample_begin: int,
  1149. max_initial_timestamp_index: Optional[int],
  1150. ):
  1151. self.tokenizer = tokenizer
  1152. self.sample_begin = sample_begin
  1153. self.max_initial_timestamp_index = max_initial_timestamp_index
  1154. def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
  1155. # suppress <|notimestamps|> which is handled by without_timestamps
  1156. if self.tokenizer.no_timestamps is not None:
  1157. logits[:, self.tokenizer.no_timestamps] = -np.inf
  1158. # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
  1159. for k in range(tokens.shape[0]):
  1160. seq = [t for t in tokens[k, self.sample_begin :].tolist()]
  1161. last_was_timestamp = (
  1162. len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
  1163. )
  1164. penultimate_was_timestamp = (
  1165. len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
  1166. )
  1167. if last_was_timestamp:
  1168. if penultimate_was_timestamp: # has to be non-timestamp
  1169. logits[k, self.tokenizer.timestamp_begin :] = -np.inf
  1170. else: # cannot be normal text tokens
  1171. logits[k, : self.tokenizer.eot] = -np.inf
  1172. # apply the `max_initial_timestamp` option
  1173. if (
  1174. tokens.shape[1] == self.sample_begin
  1175. and self.max_initial_timestamp_index is not None
  1176. ):
  1177. last_allowed = (
  1178. self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
  1179. )
  1180. logits[:, last_allowed + 1 :] = -np.inf
  1181. # if sum of probability over timestamps is above any other token, sample timestamp
  1182. logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
  1183. for k in range(tokens.shape[0]):
  1184. # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
  1185. # To bypass this issue in CI, we have decomposed the operation into separate steps.
  1186. # It will raise 2e-6 difference in precision.
  1187. # TODO: revert this after logsumexp been fixed.
  1188. timestamp_logprob = paddle.exp(
  1189. logprobs[k, self.tokenizer.timestamp_begin :]
  1190. )
  1191. timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
  1192. timestamp_logprob = paddle.log(timestamp_logprob)
  1193. max_text_token_logprob = paddle.max(
  1194. logprobs[k, : self.tokenizer.timestamp_begin]
  1195. )
  1196. if timestamp_logprob > max_text_token_logprob:
  1197. logits[k, : self.tokenizer.timestamp_begin] = -np.inf
  1198. class DecodingTask:
  1199. inference: Inference
  1200. sequence_ranker: SequenceRanker
  1201. decoder: TokenDecoder
  1202. logit_filters: List[LogitFilter]
    def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
        """Set up everything needed to decode with `model` under `options`.

        Builds the tokenizer, validates the options, computes the initial token
        sequence, and creates the inference wrapper, the sequence ranker, the
        token decoder and the list of logit filters.

        Parameters
        ----------
        model : Whisper
            The model to decode with.
        options : DecodingOptions
            Decoding options; checked by `_verify_options`.
        resource_path : str
            Forwarded to `get_tokenizer` and stored on the instance.

        Raises
        ------
        ValueError
            Propagated from `_verify_options` for incompatible options.
        """
        self.model = model

        # fall back to English when no language is given
        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            resource_path=resource_path,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
        self.resource_path: str = resource_path

        # group size per audio: beam size, else best-of count, else a single sample
        self.beam_size: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        # _get_initial_tokens reads options, tokenizer, sot_sequence, sample_len and
        # n_ctx, so it must run after the assignments above.
        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = WhisperInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                # convert the limit from seconds to a timestamp-token offset
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )
  1253. def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
  1254. if options.beam_size is not None and options.best_of is not None:
  1255. raise ValueError("beam_size and best_of can't be given together")
  1256. if options.temperature == 0:
  1257. if options.best_of is not None:
  1258. raise ValueError("best_of with greedy sampling (T=0) is not compatible")
  1259. if options.patience is not None and options.beam_size is None:
  1260. raise ValueError("patience requires beam_size to be given")
  1261. if options.length_penalty is not None and options.length_penalty != "None":
  1262. if not (0 <= options.length_penalty <= 1):
  1263. raise ValueError(
  1264. "length_penalty (alpha) should be a value between 0 and 1"
  1265. )
  1266. return options
  1267. def _get_initial_tokens(self) -> Tuple[int]:
  1268. tokens = list(self.sot_sequence)
  1269. prefix = self.options.prefix
  1270. prompt = self.options.prompt
  1271. if prefix:
  1272. prefix_tokens = (
  1273. self.tokenizer.encode(" " + prefix.strip().input_ids)
  1274. if isinstance(prefix, str)
  1275. else prefix
  1276. )
  1277. if self.sample_len is not None:
  1278. max_prefix_len = self.n_ctx // 2 - self.sample_len
  1279. prefix_tokens = prefix_tokens[-max_prefix_len:]
  1280. tokens = tokens + prefix_tokens
  1281. if prompt:
  1282. prompt_tokens = (
  1283. self.tokenizer.encode(" " + prompt.strip().input_ids)
  1284. if isinstance(prompt, str)
  1285. else prompt
  1286. )
  1287. tokens = (
  1288. [self.tokenizer.sot_prev]
  1289. + prompt_tokens[-(self.n_ctx // 2 - 1) :]
  1290. + tokens
  1291. )
  1292. return tuple(tokens)
  1293. def _get_suppress_tokens(self) -> Tuple[int]:
  1294. suppress_tokens = self.options.suppress_tokens
  1295. if isinstance(suppress_tokens, str):
  1296. suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
  1297. if -1 in suppress_tokens:
  1298. suppress_tokens = [t for t in suppress_tokens if t >= 0]
  1299. suppress_tokens.extend(self.tokenizer.non_speech_tokens)
  1300. elif suppress_tokens is None or len(suppress_tokens) == 0:
  1301. suppress_tokens = [] # interpret empty string as an empty list
  1302. else:
  1303. assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
  1304. suppress_tokens.extend(
  1305. [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
  1306. )
  1307. if self.tokenizer.no_speech is not None:
  1308. # no-speech probability is collected separately
  1309. suppress_tokens.append(self.tokenizer.no_speech)
  1310. return tuple(sorted(set(suppress_tokens)))
  1311. def _get_audio_features(self, mel: paddle.Tensor):
  1312. if mel.shape[-2:] == (
  1313. self.model.dims.n_audio_ctx,
  1314. self.model.dims.n_audio_state,
  1315. ):
  1316. # encoded audio features are given; skip audio encoding
  1317. audio_features = mel
  1318. else:
  1319. audio_features = self.model.encoder(mel)
  1320. return audio_features
  1321. def _detect_language(
  1322. self, audio_features: paddle.Tensor, tokens: paddle.Tensor, resource_path: str
  1323. ):
  1324. languages = [self.options.language] * audio_features.shape[0]
  1325. lang_probs = None
  1326. if self.options.language is None or self.options.task == "lang_id":
  1327. lang_tokens, lang_probs = self.model.detect_language(
  1328. audio_features, self.tokenizer, self.resource_path
  1329. )
  1330. languages = [max(probs, key=probs.get) for probs in lang_probs]
  1331. if self.options.language is None:
  1332. tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
  1333. return languages, lang_probs
  1334. def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
  1335. assert audio_features.shape[0] == tokens.shape[0]
  1336. n_batch = tokens.shape[0]
  1337. sum_logprobs: paddle.Tensor = paddle.zeros(
  1338. paddle.to_tensor(n_batch), dtype=paddle.float32
  1339. )
  1340. no_speech_probs = [np.nan] * n_batch
  1341. try:
  1342. for i in range(self.sample_len):
  1343. logits = self.inference.logits(tokens, audio_features)
  1344. if (
  1345. i == 0 and self.tokenizer.no_speech is not None
  1346. ): # save no_speech_probs
  1347. probs_at_sot = paddle.nn.functional.softmax(
  1348. logits[:, self.sot_index], axis=-1, dtype=paddle.float32
  1349. )
  1350. no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
  1351. # now we need to consider the logits at the last token only
  1352. logits = logits[:, -1]
  1353. # apply the logit filters, e.g. for suppressing or applying penalty to
  1354. for logit_filter in self.logit_filters:
  1355. logit_filter.apply(logits, tokens)
  1356. # expand the tokens tensor with the selected next tokens
  1357. tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
  1358. if completed or tokens.shape[-1] > self.n_ctx:
  1359. break
  1360. finally:
  1361. self.inference.cleanup_caching()
  1362. return tokens, sum_logprobs, no_speech_probs
  1363. @paddle.no_grad()
  1364. def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
  1365. self.decoder.reset()
  1366. tokenizer: Tokenizer = self.tokenizer
  1367. batch_size: int = mel.shape[0]
  1368. audio_features: paddle.Tensor = self._get_audio_features(
  1369. mel
  1370. ) # encoder forward pass
  1371. tokens: paddle.Tensor
  1372. if batch_size > 1:
  1373. for i in range(batch_size):
  1374. tokens = paddle.concat(
  1375. x=[
  1376. paddle.to_tensor([self.initial_tokens]),
  1377. paddle.to_tensor([self.initial_tokens]),
  1378. ],
  1379. axis=0,
  1380. )
  1381. elif batch_size == 1:
  1382. tokens = paddle.to_tensor([self.initial_tokens])
  1383. # detect language if requested, overwriting the language token
  1384. languages, language_probs = self._detect_language(
  1385. paddle.to_tensor(audio_features),
  1386. paddle.to_tensor(tokens),
  1387. self.resource_path,
  1388. )
  1389. if self.options.task == "lang_id":
  1390. return [
  1391. DecodingResult(
  1392. audio_features=features, language=language, language_probs=probs
  1393. )
  1394. for features, language, probs in zip(
  1395. audio_features, languages, language_probs
  1396. )
  1397. ]
  1398. # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
  1399. audio_features = paddle.repeat_interleave(
  1400. audio_features, self.beam_size, axis=0
  1401. )
  1402. tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
  1403. # call the main sampling loop
  1404. tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  1405. # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
  1406. audio_features = audio_features[:: self.beam_size]
  1407. no_speech_probs = no_speech_probs[:: self.beam_size]
  1408. assert audio_features.shape[0] == len(no_speech_probs) == batch_size
  1409. tokens = tokens.reshape([batch_size, self.beam_size, -1])
  1410. sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
  1411. # get the final candidates for each group, and slice between the first sampled token and EOT
  1412. tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
  1413. tokens: List[List[paddle.Tensor]] = [
  1414. [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
  1415. for s in tokens
  1416. ]
  1417. # select the top-ranked sample in each group
  1418. selected = self.sequence_ranker.rank(tokens, sum_logprobs)
  1419. tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
  1420. texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
  1421. sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
  1422. avg_logprobs: List[float] = [
  1423. lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
  1424. ]
  1425. fields = (
  1426. texts,
  1427. languages,
  1428. tokens,
  1429. audio_features,
  1430. avg_logprobs,
  1431. no_speech_probs,
  1432. )
  1433. if len(set(map(len, fields))) != 1:
  1434. raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
  1435. return [
  1436. DecodingResult(
  1437. audio_features=features,
  1438. language=language,
  1439. tokens=tokens,
  1440. text=text,
  1441. avg_logprob=avg_logprob,
  1442. no_speech_prob=no_speech_prob,
  1443. temperature=self.options.temperature,
  1444. compression_ratio=compression_ratio(text),
  1445. )
  1446. for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
  1447. *fields
  1448. )
  1449. ]
  1450. @paddle.no_grad()
  1451. def decode(
  1452. model: "Whisper",
  1453. mel: paddle.Tensor,
  1454. options: DecodingOptions = DecodingOptions(),
  1455. resource_path=str,
  1456. ) -> Union[DecodingResult, List[DecodingResult]]:
  1457. """
  1458. Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
  1459. Parameters
  1460. ----------
  1461. model: Whisper
  1462. the Whisper model instance
  1463. mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
  1464. A tensor containing the Mel spectrogram(s)
  1465. options: DecodingOptions
  1466. A dataclass that contains all necessary options for decoding 30-second segments
  1467. Returns
  1468. -------
  1469. result: Union[DecodingResult, List[DecodingResult]]
  1470. The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
  1471. """
  1472. single = mel.ndim == 2
  1473. if single:
  1474. mel = mel.unsqueeze(0)
  1475. result = DecodingTask(model, options, resource_path).run(mel)
  1476. if single:
  1477. result = result[0]
  1478. return result
  1479. class Whisper(paddle.nn.Layer):
  1480. """
  1481. The `Whisper` module use AudioEncoder and TextDecoder, and return detect_language, transcribe, decode.
  1482. """
  1483. def __init__(self, dims: ModelDimensions):
  1484. super().__init__()
  1485. self.dims = dims
  1486. self.encoder = AudioEncoder(
  1487. self.dims.n_mels,
  1488. self.dims.n_audio_ctx,
  1489. self.dims.n_audio_state,
  1490. self.dims.n_audio_head,
  1491. self.dims.n_audio_layer,
  1492. )
  1493. self.decoder = TextDecoder(
  1494. self.dims.n_vocab,
  1495. self.dims.n_text_ctx,
  1496. self.dims.n_text_state,
  1497. self.dims.n_text_head,
  1498. self.dims.n_text_layer,
  1499. )
  1500. def embed_audio(self, mel: paddle.Tensor):
  1501. return self.encoder.forward(mel)
  1502. def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
  1503. return self.decoder.forward(tokens, audio_features)
  1504. def forward(
  1505. self, mel: paddle.Tensor, tokens: paddle.Tensor
  1506. ) -> Dict[str, paddle.Tensor]:
  1507. return self.decoder(tokens, self.encoder(mel))
  1508. @property
  1509. def device(self):
  1510. return paddle.device.get_device()
  1511. @property
  1512. def is_multilingual(self):
  1513. return self.dims.n_vocab == 51865
  1514. def install_kv_cache_hooks(self, cache: Optional[dict] = None):
  1515. """
  1516. The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
  1517. tensors calculated for the previous positions. This method returns a dictionary that stores
  1518. all caches, and the necessary hooks for the key and value projection modules that save the
  1519. intermediate tensors to be reused during later calculations.
  1520. Returns
  1521. -------
  1522. cache : Dict[nn.Layer, paddle.Tensor]
  1523. A dictionary object mapping the key/value projection modules to its cache
  1524. hooks : List[RemovableHandle]
  1525. List of PyTorch RemovableHandle objects to stop the hooks to be called
  1526. """
  1527. cache = {**cache} if cache is not None else {}
  1528. hooks = []
  1529. def save_to_cache(module, _, output):
  1530. if (
  1531. module not in cache
  1532. or output.shape[1] > self.decoder.positional_embedding.shape[0]
  1533. ):
  1534. cache[module] = (
  1535. output # save as-is, for the first token or cross attention
  1536. )
  1537. else:
  1538. cache[module] = paddle.concat([cache[module], output], axis=1).detach()
  1539. return cache[module]
  1540. def install_hooks(layer: paddle.nn.Layer):
  1541. if isinstance(layer, MultiHeadAttention):
  1542. hooks.append(layer.key.register_forward_post_hook(save_to_cache))
  1543. hooks.append(layer.value.register_forward_post_hook(save_to_cache))
  1544. self.decoder.apply(install_hooks)
  1545. return cache, hooks
  1546. detect_language = detect_language
  1547. transcribe = transcribe
  1548. decode = decode
  1549. def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
  1550. """
  1551. Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
  1552. """
  1553. if paddle.is_tensor(array):
  1554. if array.shape[axis] > length:
  1555. array = array.index_select(axis=axis, index=paddle.arange(length))
  1556. if array.shape[axis] < length:
  1557. pad_widths = [(0, 0)] * array.ndim
  1558. pad_widths[axis] = (0, length - array.shape[axis])
  1559. array = paddle.transpose(array, (1, 0))
  1560. array = paddle.nn.functional.pad(
  1561. array,
  1562. [pad for sizes in pad_widths[::-1] for pad in sizes],
  1563. data_format="NLC",
  1564. )
  1565. array = paddle.transpose(array, (1, 0))
  1566. else:
  1567. if array.shape[axis] > length:
  1568. array = array.take(indices=range(length), axis=axis)
  1569. if array.shape[axis] < length:
  1570. pad_widths = [(0, 0)] * array.ndim
  1571. pad_widths[axis] = (0, length - array.shape[axis])
  1572. array = paddle.transpose(array, (1, 0))
  1573. array = np.pad(array, pad_widths)
  1574. array = paddle.transpose(array, (1, 0))
  1575. return array
  1576. def hann_window(n_fft: int = N_FFT):
  1577. """
  1578. hanning window
  1579. n_fft: The number of frequency components of the discrete Fourier transform.
  1580. """
  1581. return paddle.to_tensor(
  1582. [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
  1583. dtype=paddle.float32,
  1584. )
  1585. @lru_cache(maxsize=None)
  1586. def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
  1587. """
  1588. load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
  1589. Allows decoupling librosa dependency; saved using:
  1590. np.savez_compressed(
  1591. "mel_filters.npz",
  1592. mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
  1593. )
  1594. """
  1595. assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
  1596. with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
  1597. return paddle.to_tensor(f[f"mel_{n_mels}"])
  1598. def log_mel_spectrogram(
  1599. audio: Union[str, np.ndarray, paddle.Tensor],
  1600. n_mels: int = N_MELS,
  1601. resource_path: str = None,
  1602. ):
  1603. """
  1604. Compute the log-Mel spectrogram of
  1605. Parameters
  1606. ----------
  1607. audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
  1608. The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
  1609. n_mels: int
  1610. The number of Mel-frequency filters, only 80 is supported
  1611. Returns
  1612. -------
  1613. paddle.Tensor, shape = (80, n_frames)
  1614. A Tensor that contains the Mel spectrogram
  1615. """
  1616. if not paddle.is_tensor(audio):
  1617. if isinstance(audio, str):
  1618. audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
  1619. audio = audio[:, 0]
  1620. audio = paddle.to_tensor(audio)
  1621. window = hann_window(N_FFT)
  1622. stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
  1623. magnitudes = stft[:, :-1].abs() ** 2
  1624. filters = mel_filters(resource_path, n_mels)
  1625. mel_spec = filters @ magnitudes
  1626. mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
  1627. log_spec = paddle.clip(mel_spec, min=1e-10).log10()
  1628. log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
  1629. log_spec = (log_spec + 4.0) / 4.0
  1630. return log_spec