unitable_modules.py

from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.transformer import _get_activation_fn

TOKEN_WHITE_LIST = [1] + list(range(12, 510))
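
# The white list is consumed in GPTFastDecoder.forward: logits for every id *not*
# in TOKEN_WHITE_LIST are set to -1e9 before the softmax, so only these ids can be
# emitted.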


class ImgLinearBackbone(nn.Module):
    def __init__(
        self,
        d_model: int,
        patch_size: int,
        in_chan: int = 3,
    ) -> None:
        super().__init__()
        # Linear patch embedding implemented as a strided convolution.
        self.conv_proj = nn.Conv2d(
            in_chan,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_proj(x)
        # (B, d_model, H', W') -> (B, H'*W', d_model)
        x = x.flatten(start_dim=-2).transpose(1, 2)
        return x
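
# Shape sketch (illustrative 448x448 RGB input, not a value fixed by this file):
# conv_proj maps (B, 3, 448, 448) -> (B, d_model, 28, 28) with patch_size=16,
# then flatten + transpose yields (B, 784, d_model) patch tokens.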


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.patch_size = 16
        self.d_model = 768
        self.dropout = 0
        self.activation = "gelu"
        self.norm_first = True
        self.ff_ratio = 4
        self.nhead = 12
        self.max_seq_len = 1024
        self.n_encoder_layer = 12
        encoder_layer = nn.TransformerEncoderLayer(
            self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.ff_ratio * self.d_model,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
        )
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(self.d_model)
        self.backbone = ImgLinearBackbone(
            d_model=self.d_model, patch_size=self.patch_size
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
        )

    def forward(self, x: Tensor) -> Tensor:
        src_feature = self.backbone(x)
        src_feature = self.pos_embed(src_feature)
        memory = self.encoder(src_feature)
        memory = self.norm(memory)
        return memory
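
# The returned `memory` is a (B, num_patches, 768) sequence consumed by the
# decoder's CrossAttention. With patch_size=16, num_patches must not exceed
# max_seq_len=1024, e.g. a 512x512 input gives exactly 32 * 32 = 1024 patches.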


class PositionEmbedding(nn.Module):
    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
        super().__init__()
        # Learned absolute position embedding, looked up by position index.
        self.embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # assume x is batch first
        if input_pos is None:
            _pos = torch.arange(x.shape[1], device=x.device)
        else:
            # During incremental decoding only the current position is passed in.
            _pos = input_pos
        out = self.embedding(_pos)
        return self.dropout(out + x)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        padding_idx: int,
    ) -> None:
        super().__init__()
        assert vocab_size > 0
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)

    def forward(self, x: Tensor) -> Tensor:
        return self.embedding(x)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
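
# e.g. find_multiple(37, 8) == 40 and find_multiple(2048, 256) == 2048:
# the smallest multiple of k that is >= n.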


@dataclass
class ModelArgs:
    n_layer: int = 4
    n_head: int = 12
    dim: int = 768
    intermediate_size: Optional[int] = None
    head_dim: int = 64
    activation: str = "gelu"
    norm_first: bool = True

    def __post_init__(self):
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head
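
# With the defaults above: hidden_dim = 4 * 768 = 3072, n_hidden = int(2 * 3072 / 3)
# = 2048, intermediate_size = find_multiple(2048, 256) = 2048, head_dim = 768 // 12 = 64.
# GPTFastDecoder below passes intermediate_size = d_model * 4 = 3072 explicitly.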


class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cpu",
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer(
            "k_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "v_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        # assert input_pos.shape[0] == k_val.shape[2]
        bs = k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:bs, :, input_pos] = k_val
        v_out[:bs, :, input_pos] = v_val
        return k_out[:bs], v_out[:bs]
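
# KVCache.update writes the new key/value slice at `input_pos` into preallocated
# (max_batch_size, n_heads, max_seq_length, head_dim) buffers and returns the full
# buffers, so self-attention always sees the complete decoded history.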


class GPTFastDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vocab_size = 960
        self.padding_idx = 2
        self.prefix_token_id = 11
        self.eos_id = 1
        self.max_seq_len = 1024
        self.dropout = 0
        self.d_model = 768
        self.nhead = 12
        self.activation = "gelu"
        self.norm_first = True
        self.n_decoder_layer = 4
        config = ModelArgs(
            n_layer=self.n_decoder_layer,
            n_head=self.nhead,
            dim=self.d_model,
            intermediate_size=self.d_model * 4,
            activation=self.activation,
            norm_first=self.norm_first,
        )
        self.config = config
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.token_embed = TokenEmbedding(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            padding_idx=self.padding_idx,
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.generator = nn.Linear(self.d_model, self.vocab_size)
        self.token_white_list = TOKEN_WHITE_LIST
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
        # Reset the cross-attention caches; they are lazily rebuilt on the next forward.
        for b in self.layers:
            b.multihead_attn.k_cache = None
            b.multihead_attn.v_cache = None
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.self_attn.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_head,
                head_dim,
                dtype,
                device,
            )
            b.multihead_attn.k_cache = None
            b.multihead_attn.v_cache = None
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        ).to(device)

    def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
        # Single-step decoding: only the last target token is embedded and attended.
        input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
        tgt = tgt[:, -1:]
        tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
        # tgt = self.decoder(tgt_feature, memory, input_pos)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            logits = tgt_feature
            tgt_mask = self.causal_mask[None, None, input_pos]
            for i, layer in enumerate(self.layers):
                logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
        # return output
        logits = self.generator(logits)[:, -1, :]
        # Mask every token id outside the white list before picking the argmax.
        total = set(range(logits.shape[-1]))
        black_list = list(total.difference(set(self.token_white_list)))
        logits[..., black_list] = -1e9
        probs = F.softmax(logits, dim=-1)
        _, next_tokens = probs.topk(1)
        return next_tokens
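
# Call contract: each forward() performs ONE decoding step. `tgt` is the full prefix
# of token ids decoded so far, shape (B, T); only the last token is embedded, the KV
# caches supply the history, and a (B, 1) tensor holding the next token id is
# returned. setup_caches(...) must be called before the first step.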


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(config)
        self.multihead_attn = CrossAttention(config)
        layer_norm_eps = 1e-5
        d_model = config.dim
        dim_feedforward = config.intermediate_size
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = config.norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.activation = _get_activation_fn(config.activation)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor,
        input_pos: Tensor,
    ) -> Tensor:
        if self.norm_first:
            # Pre-norm: self-attention, cross-attention, feed-forward, each with a residual.
            x = tgt
            x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
            x = x + self.multihead_attn(self.norm2(x), memory)
            x = x + self._ff_block(self.norm3(x))
        else:
            # Post-norm variant.
            x = tgt
            x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
            x = self.norm2(x + self.multihead_attn(x, memory))
            x = self.norm3(x + self._ff_block(x))
        return x

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, 3 * config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        self.kv_cache: Optional[KVCache] = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.dim = config.dim

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape
        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        # (B, S, dim) -> (B, n_head, S, head_dim)
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y
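
# During incremental decoding seqlen == 1: the single new key/value pair is written
# into the KVCache at `input_pos`, and attention then runs against the whole cached
# history, restricted by the causal-mask row passed in as `mask`.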


class CrossAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.query = nn.Linear(config.dim, config.dim)
        self.key = nn.Linear(config.dim, config.dim)
        self.value = nn.Linear(config.dim, config.dim)
        self.out = nn.Linear(config.dim, config.dim)
        # Encoder memory is static during decoding, so its keys/values are computed
        # once and cached.
        self.k_cache = None
        self.v_cache = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim

    def get_kv(self, xa: torch.Tensor):
        if self.k_cache is not None and self.v_cache is not None:
            return self.k_cache, self.v_cache
        k = self.key(xa)
        v = self.value(xa)
        # Reshape for correct format
        batch_size, source_seq_len, _ = k.shape
        k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
        v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
        if self.k_cache is None:
            self.k_cache = k
        if self.v_cache is None:
            self.v_cache = v
        return k, v

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
    ):
        q = self.query(x)
        batch_size, target_seq_len, _ = q.shape
        q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
        k, v = self.get_kv(xa)
        # (B, S, n_head, head_dim) -> (B, n_head, S, head_dim)
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).reshape(
            batch_size,
            target_seq_len,
            self.n_head * self.head_dim,
        )
        return self.out(wv)
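

# Minimal smoke-test sketch (not part of the original module): wires the encoder and
# decoder together for a short greedy rollout with random weights. The 448x448 input
# size and the 32-step cap are illustrative assumptions, not values fixed by this file.
if __name__ == "__main__":
    device = "cpu"
    encoder = Encoder().eval()
    decoder = GPTFastDecoder().eval()

    with torch.no_grad():
        image = torch.randn(1, 3, 448, 448, device=device)
        memory = encoder(image)  # (1, (448 // 16) ** 2, 768) = (1, 784, 768)

        decoder.setup_caches(
            max_batch_size=1,
            max_seq_length=decoder.max_seq_len,
            dtype=torch.float32,
            device=device,
        )
        # Start from the prefix token and greedily append one id per step.
        tgt = torch.tensor([[decoder.prefix_token_id]], device=device)
        for _ in range(32):
            next_token = decoder(memory, tgt)  # (1, 1)
            tgt = torch.cat([tgt, next_token], dim=1)
            if next_token.item() == decoder.eos_id:
                break
        print("decoded ids:", tgt.tolist())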