from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.transformer import _get_activation_fn

# Allowed output token ids: id 1 (the EOS token) plus ids 12 through 509.
TOKEN_WHITE_LIST = [1] + list(range(12, 510))

class ImgLinearBackbone(nn.Module):
    """Splits an image into non-overlapping patches and linearly projects each patch to d_model."""

    def __init__(
        self,
        d_model: int,
        patch_size: int,
        in_chan: int = 3,
    ) -> None:
        super().__init__()
        # A conv with kernel_size == stride == patch_size is equivalent to a
        # per-patch linear projection.
        self.conv_proj = nn.Conv2d(
            in_chan,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, H, W) -> (B, d_model, H/P, W/P) -> (B, N, d_model)
        x = self.conv_proj(x)
        x = x.flatten(start_dim=-2).transpose(1, 2)
        return x

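# Minimal shape sketch for the backbone. The 448x448 input size is an
# assumption for illustration; any size divisible by the patch size works.
# A 448x448 image with patch_size=16 yields (448 / 16) ** 2 = 784 patch tokens.
def _demo_img_backbone() -> None:
    backbone = ImgLinearBackbone(d_model=768, patch_size=16)
    dummy = torch.zeros(1, 3, 448, 448)
    tokens = backbone(dummy)
    assert tokens.shape == (1, 784, 768)
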
class Encoder(nn.Module):
    """ViT-style image encoder: patch embedding + learned positions + transformer encoder."""

    def __init__(self) -> None:
        super().__init__()
        self.patch_size = 16
        self.d_model = 768
        self.dropout = 0
        self.activation = "gelu"
        self.norm_first = True
        self.ff_ratio = 4
        self.nhead = 12
        self.max_seq_len = 1024
        self.n_encoder_layer = 12
        encoder_layer = nn.TransformerEncoderLayer(
            self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.ff_ratio * self.d_model,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
        )
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(self.d_model)
        self.backbone = ImgLinearBackbone(
            d_model=self.d_model, patch_size=self.patch_size
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
        )

    def forward(self, x: Tensor) -> Tensor:
        # Image -> patch tokens -> position-embedded tokens -> encoder memory.
        src_feature = self.backbone(x)
        src_feature = self.pos_embed(src_feature)
        memory = self.encoder(src_feature)
        memory = self.norm(memory)
        return memory

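# A minimal usage sketch for the encoder. The 448x448 input size is an
# assumption; the encoder itself only requires that the number of patches
# not exceed max_seq_len=1024.
def _demo_encoder() -> None:
    encoder = Encoder()
    image = torch.zeros(1, 3, 448, 448)
    memory = encoder(image)  # (1, 784, 768) memory consumed by cross-attention
    assert memory.shape == (1, 784, 768)
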
class PositionEmbedding(nn.Module):
    """Learned absolute position embedding added to the input sequence."""

    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
        super().__init__()
        self.embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # x is batch first: (B, S, d_model). When input_pos is given (incremental
        # decoding), only the embeddings at those positions are added.
        if input_pos is None:
            _pos = torch.arange(x.shape[1], device=x.device)
        else:
            _pos = input_pos
        out = self.embedding(_pos)
        return self.dropout(out + x)

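# A small sketch of the two calling modes (the shapes are illustrative
# assumptions): full-sequence mode embeds positions 0..S-1, while incremental
# mode embeds only the position of the token currently being decoded.
def _demo_position_embedding() -> None:
    pe = PositionEmbedding(max_seq_len=1024, d_model=768, dropout=0.0)
    full = pe(torch.zeros(1, 10, 768))  # positions 0..9
    step = pe(torch.zeros(1, 1, 768), input_pos=torch.tensor([9]))  # position 9 only
    assert full.shape == (1, 10, 768) and step.shape == (1, 1, 768)
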
class TokenEmbedding(nn.Module):
    """Token-id to d_model embedding with a dedicated padding index."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        padding_idx: int,
    ) -> None:
        super().__init__()
        assert vocab_size > 0
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)

    def forward(self, x: Tensor) -> Tensor:
        return self.embedding(x)

def find_multiple(n: int, k: int) -> int:
    """Round n up to the nearest multiple of k, e.g. find_multiple(1023, 8) == 1024."""
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class ModelArgs:
    n_layer: int = 4
    n_head: int = 12
    dim: int = 768
    intermediate_size: Optional[int] = None
    head_dim: int = 64
    activation: str = "gelu"
    norm_first: bool = True

    def __post_init__(self):
        if self.intermediate_size is None:
            # Default sizing: 2/3 of 4*dim, rounded up to a multiple of 256
            # (skipped when intermediate_size is passed explicitly).
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

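# Worked example of the default sizing, following directly from the code above:
# dim=768 -> hidden_dim=3072 -> n_hidden=2048 -> find_multiple(2048, 256) = 2048,
# and head_dim = 768 // 12 = 64. GPTFastDecoder below overrides
# intermediate_size with d_model * 4 = 3072 instead.
def _demo_model_args() -> None:
    defaulted = ModelArgs()
    assert defaulted.intermediate_size == 2048 and defaulted.head_dim == 64
    explicit = ModelArgs(intermediate_size=3072)
    assert explicit.intermediate_size == 3072
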
class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental (token-by-token) decoding."""

    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cpu",
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer(
            "k_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "v_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]. Writes the new keys/values at the
        # given positions and returns the full-length cache views.
        bs = k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:bs, :, input_pos] = k_val
        v_out[:bs, :, input_pos] = v_val
        return k_out[:bs], v_out[:bs]

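# A minimal sketch of how the cache is filled during decoding. float32 on CPU
# is an assumption for the example; the decoder below uses whatever dtype and
# device are passed to setup_caches.
def _demo_kv_cache() -> None:
    cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4,
                    dtype=torch.float32)
    k = torch.ones(1, 2, 1, 4)
    v = torch.ones(1, 2, 1, 4)
    # Write the key/value of the token at position 3; all other slots stay zero.
    k_all, v_all = cache.update(torch.tensor([3]), k, v)
    assert k_all.shape == (1, 2, 8, 4)
    assert bool(k_all[0, 0, 3].eq(1).all()) and bool(k_all[0, 0, 0].eq(0).all())
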
class GPTFastDecoder(nn.Module):
    """Autoregressive decoder over encoder memory with KV caching; decodes one token per call."""

    def __init__(self) -> None:
        super().__init__()
        self.vocab_size = 960
        self.padding_idx = 2
        self.prefix_token_id = 11
        self.eos_id = 1
        self.max_seq_len = 1024
        self.dropout = 0
        self.d_model = 768
        self.nhead = 12
        self.activation = "gelu"
        self.norm_first = True
        self.n_decoder_layer = 4
        config = ModelArgs(
            n_layer=self.n_decoder_layer,
            n_head=self.nhead,
            dim=self.d_model,
            intermediate_size=self.d_model * 4,
            activation=self.activation,
            norm_first=self.norm_first,
        )
        self.config = config
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.token_embed = TokenEmbedding(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            padding_idx=self.padding_idx,
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.generator = nn.Linear(self.d_model, self.vocab_size)
        self.token_white_list = TOKEN_WHITE_LIST
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
        # Cross-attention keys/values depend on the current image, so they are
        # cleared whenever caches are (re)initialised.
        for b in self.layers:
            b.multihead_attn.k_cache = None
            b.multihead_attn.v_cache = None
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.self_attn.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_head,
                head_dim,
                dtype,
                device,
            )
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        ).to(device)

    def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
        # Decode only the last token of tgt; earlier positions already live in the KV caches.
        input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
        tgt = tgt[:, -1:]
        tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
        # Force the math SDPA backend; flash and mem-efficient kernels are disabled here.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            logits = tgt_feature
            tgt_mask = self.causal_mask[None, None, input_pos]
            for layer in self.layers:
                logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
        logits = self.generator(logits)[:, -1, :]
        # Constrain decoding to the white-listed token ids by masking out every
        # other vocabulary entry, then pick the most likely remaining token.
        black_list = list(set(range(logits.shape[-1])).difference(self.token_white_list))
        logits[..., black_list] = -1e9
        probs = F.softmax(logits, dim=-1)
        _, next_tokens = probs.topk(1)
        return next_tokens

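# A minimal greedy-decoding sketch tying the encoder and decoder together.
# The image size, max_decode_len, and float32 caches are assumptions for
# illustration; prefix_token_id and eos_id come from the decoder's attributes.
def _demo_generate(image: Tensor, max_decode_len: int = 512) -> Tensor:
    encoder = Encoder()
    decoder = GPTFastDecoder()
    memory = encoder(image)
    decoder.setup_caches(
        max_batch_size=image.shape[0],
        max_seq_length=decoder.max_seq_len,
        dtype=torch.float32,
        device=image.device,
    )
    # Start every sequence with the prefix token, then append one token per step.
    tokens = torch.full(
        (image.shape[0], 1), decoder.prefix_token_id, dtype=torch.long, device=image.device
    )
    for _ in range(max_decode_len):
        next_token = decoder(memory, tokens)  # (B, 1) greedy pick from the white list
        tokens = torch.cat([tokens, next_token], dim=1)
        if (next_token == decoder.eos_id).all():
            break
    return tokens
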
class TransformerBlock(nn.Module):
    """Decoder block: causal self-attention, cross-attention over memory, and a feed-forward sublayer."""

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(config)
        self.multihead_attn = CrossAttention(config)
        layer_norm_eps = 1e-5
        d_model = config.dim
        dim_feedforward = config.intermediate_size
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = config.norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.activation = _get_activation_fn(config.activation)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor,
        input_pos: Tensor,
    ) -> Tensor:
        if self.norm_first:
            # Pre-norm residual path (as in nn.TransformerDecoderLayer with norm_first=True).
            x = tgt
            x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
            x = x + self.multihead_attn(self.norm2(x), memory)
            x = x + self._ff_block(self.norm3(x))
        else:
            # Post-norm residual path.
            x = tgt
            x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
            x = self.norm2(x + self.multihead_attn(x, memory))
            x = self.norm3(x + self._ff_block(x))
        return x

    def _ff_block(self, x: Tensor) -> Tensor:
        return self.linear2(self.activation(self.linear1(x)))

class Attention(nn.Module):
    """Causal self-attention with a fused QKV projection and an optional KV cache."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        # Key, query, and value projections for all heads, computed in one matmul.
        self.wqkv = nn.Linear(config.dim, 3 * config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        self.kv_cache: Optional[KVCache] = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.dim = config.dim

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape
        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
        if self.kv_cache is not None:
            # Attend over the full cached sequence, not just the current token.
            k, v = self.kv_cache.update(input_pos, k, v)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y

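# A minimal single-step sketch of the fused-QKV attention (shapes are
# illustrative assumptions). Without a KVCache attached, the module behaves
# like ordinary masked self-attention over whatever sequence it is given.
def _demo_attention() -> None:
    cfg = ModelArgs(n_layer=1, n_head=12, dim=768)
    attn = Attention(cfg)
    x = torch.zeros(2, 5, 768)
    mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))
    y = attn(x, mask)
    assert y.shape == (2, 5, 768)
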
class CrossAttention(nn.Module):
    """Cross-attention over the encoder memory; keys/values are computed once and reused."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.query = nn.Linear(config.dim, config.dim)
        self.key = nn.Linear(config.dim, config.dim)
        self.value = nn.Linear(config.dim, config.dim)
        self.out = nn.Linear(config.dim, config.dim)
        self.k_cache = None
        self.v_cache = None
        self.n_head = config.n_head
        self.head_dim = config.head_dim

    def get_kv(self, xa: torch.Tensor):
        # The encoder memory is fixed during decoding, so its projected
        # keys/values are computed on the first call and cached afterwards.
        if self.k_cache is not None and self.v_cache is not None:
            return self.k_cache, self.v_cache
        k = self.key(xa)
        v = self.value(xa)
        # Reshape to (B, S_mem, n_head, head_dim).
        batch_size, source_seq_len, _ = k.shape
        k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
        v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
        if self.k_cache is None:
            self.k_cache = k
        if self.v_cache is None:
            self.v_cache = v
        return k, v

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
    ):
        q = self.query(x)
        batch_size, target_seq_len, _ = q.shape
        q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
        k, v = self.get_kv(xa)
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).reshape(
            batch_size,
            target_seq_len,
            self.n_head * self.head_dim,
        )
        return self.out(wv)

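# A small sketch showing that the cross-attention keys/values are computed only
# once per image: the second call reuses the cached tensors. Shapes are
# illustrative assumptions.
def _demo_cross_attention_cache() -> None:
    cfg = ModelArgs(n_layer=1, n_head=12, dim=768)
    xattn = CrossAttention(cfg)
    memory = torch.zeros(1, 784, 768)
    _ = xattn(torch.zeros(1, 1, 768), memory)
    k_first = xattn.k_cache
    _ = xattn(torch.zeros(1, 1, 768), memory)
    assert xattn.k_cache is k_first  # reused, not recomputed
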