# modeling_dots_vision.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

try:
    from flash_attn import flash_attn_varlen_func

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

    def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False, **kwargs):
        """
        Float16 optimized fallback implementation for flash_attn_varlen_func.
        Optimized for Apple Silicon MPS.
        """
        print("Flash Attention not available. Using float16 MPS-optimized fallback.")
        # q, k, v shapes: (total_seq_len, num_heads, head_dim)
        batch_size = len(cu_seqlens_q) - 1
        outputs = []
        for i in range(batch_size):
            start_q = cu_seqlens_q[i]
            end_q = cu_seqlens_q[i + 1]
            start_k = cu_seqlens_k[i]
            end_k = cu_seqlens_k[i + 1]

            q_seq = q[start_q:end_q]  # (seq_len_q, num_heads, head_dim)
            k_seq = k[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
            v_seq = v[start_k:end_k]  # (seq_len_k, num_heads, head_dim)

            # Transpose for standard attention: (num_heads, seq_len, head_dim)
            q_seq = q_seq.transpose(0, 1)
            k_seq = k_seq.transpose(0, 1)
            v_seq = v_seq.transpose(0, 1)

            # Standard scaled dot-product attention with float16 optimization
            scores = torch.matmul(q_seq, k_seq.transpose(-2, -1)) / math.sqrt(q_seq.size(-1))

            # Apply causal mask if needed
            if causal and q_seq.size(1) > 1:
                seq_len = q_seq.size(1)
                causal_mask = torch.triu(
                    torch.ones(seq_len, seq_len, device=q.device, dtype=q.dtype), diagonal=1
                ).bool()
                scores.masked_fill_(causal_mask, float("-inf"))

            # Use float32 for softmax stability, then convert back to float16
            attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
            attn_output = torch.matmul(attn_weights, v_seq)

            # Transpose back: (seq_len, num_heads, head_dim)
            attn_output = attn_output.transpose(0, 1)
            outputs.append(attn_output)

        # Concatenate all sequences
        return torch.cat(outputs, dim=0)


from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel

from .configuration_dots import DotsVisionConfig

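# Note on the varlen ("packed") attention layout used throughout this file: patch tokens
# from all images are concatenated along dim 0, and cu_seqlens holds the cumulative
# sequence lengths prefixed with 0. Illustrative example (values are made up): two images
# contributing 4 and 6 patch tokens are packed into a (10, num_heads, head_dim) tensor
# with cu_seqlens = [0, 4, 10], and every attention implementation below restricts
# attention to each [cu_seqlens[i - 1], cu_seqlens[i]) span.
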
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    # For float16, use float32 for computation stability
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    # Convert back to original dtype (float16 for MPS efficiency)
    output = output.to(orig_dtype)
    return output

class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

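# How the rotary pieces fit together (shapes are illustrative, derived from this file):
# VisionRotaryEmbedding is constructed with dim = head_dim // 2, so forward(max_grid_size)
# returns a frequency table of shape (max_grid_size, head_dim // 4). rot_pos_emb() below
# indexes that table with each patch's (h, w) position and flattens to
# (num_tokens, head_dim // 2), so half of the rotary frequencies encode a patch's row
# index and half its column index; apply_rotary_pos_emb_vision then tiles cos/sin up to
# head_dim and rotates q and k.
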
class PatchMerger(nn.Module):
    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        init_merger_std=None,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            # RMSNorm is defined later in this module; the name is resolved at call time.
            self.ln_q = RMSNorm(context_dim, eps=1e-6)
        else:
            print("no norm in patch merger")
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )
        if init_merger_std is not None:
            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
            nn.init.zeros_(self.mlp[0].bias)
            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
            nn.init.zeros_(self.mlp[2].bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_norm:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        else:
            x = self.mlp(x.view(-1, self.hidden_size))
        return x

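# The merger reduces the token count by spatial_merge_size**2: with the default merge
# size of 2 it views every 4 consecutive patch embeddings of width context_dim as one
# vector of width 4 * context_dim and projects it to `dim`. Illustrative shapes: an
# input of (1024, embed_dim) becomes (256, 4 * embed_dim) inside the view and
# (256, dim) after the MLP.
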
class VisionAttention(nn.Module):
    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Additive block-diagonal mask: tokens may only attend within their own packed sequence.
        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output

class VisionFlashAttention2(nn.Module):
    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.config = config
        self.is_causal = config.is_causal

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )  # 'shd'
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=self.is_causal
        ).reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output

class VisionSdpaAttention(nn.Module):
    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Boolean block-diagonal mask for SDPA: True marks positions that may be attended to.
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output

DOTS_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}

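# The key comes from config.attn_implementation (see DotsVisionBlock below). If the
# flash_attn package is not installed, "flash_attention_2" still works because the module
# falls back to the pure-PyTorch flash_attn_varlen_func defined at the top of this file;
# the mask-based "eager" and "sdpa" paths do not depend on flash_attn at all.
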
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

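# RMSNorm computes y = weight * x / sqrt(mean(x**2, dim=-1) + eps); unlike LayerNorm it
# subtracts no mean and has no bias. The normalization itself runs in float32 (x.float())
# before casting back, which matters for float16/MPS execution.
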
class DotsSwiGLUFFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.embed_dim
        bias = config.use_bias
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.fc1(x)) * self.fc3(x)
        x = self.fc2(x)
        return x

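# SwiGLU feed-forward: out = fc2(silu(fc1(x)) * fc3(x)), i.e. a gated MLP where fc1
# produces the gate, fc3 the value, and fc2 projects back from intermediate_size to
# embed_dim.
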
class DotsPatchEmbed(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.proj = nn.Conv2d(
            config.num_channels,
            config.embed_dim,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        # Unflatten each packed patch and keep only the first temporal slice ([:, :, 0]),
        # then patchify it with a kernel = stride = patch_size Conv2d into one embed_dim vector.
        x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0]
        x = self.proj(x).view(-1, self.embed_dim)
        x = self.norm(x)
        return x

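# Illustrative shapes, assuming num_channels=3, temporal_patch_size=2 and patch_size=14
# (example values, not taken from any particular config): the packed input must be
# viewable as (-1, 3, 2, 14, 14), so each row of x carries 3 * 2 * 14 * 14 = 1176 values;
# after [:, :, 0] and the Conv2d, every patch becomes exactly one embed_dim-dimensional token.
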
class DotsViTPreprocessor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_h = config.patch_size
        self.patch_w = config.patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        tokens = self.patchifier(x, grid_thw)
        return tokens

class DotsVisionBlock(nn.Module):
    def __init__(self, config, attn_implementation: str = "flash_attention_2"):
        super().__init__()
        self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
            config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(config)
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        # Pre-norm residual block: x = x + Attn(RMSNorm(x)); x = x + FFN(RMSNorm(x))
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states

class DotsVisionTransformer(PreTrainedModel):
    def __init__(self, config: DotsVisionConfig) -> None:
        super().__init__(config)
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)
        self._init_weights(self.patch_embed.patchifier.proj)

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        _num_hidden_layers = config.num_hidden_layers
        self.blocks = nn.ModuleList(
            [DotsVisionBlock(config, config.attn_implementation) for _ in range(_num_hidden_layers)]
        )
        if self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            init_merger_std=self.config.init_merger_std,
        )

        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint

    def _init_weights(self, module):
        std = self.config.initializer_range
        # The patch embedding here uses nn.Conv2d, so include it alongside nn.Conv3d;
        # otherwise the explicit _init_weights call in __init__ would be a no-op.
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def get_pos_ids_by_grid(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
            )
        return pos_ids

    def rot_pos_emb(self, grid_thw):
        pos_ids = self.get_pos_ids_by_grid(grid_thw)
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

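    # The (h, w) position ids above are emitted in spatial_merge_size x spatial_merge_size
    # window order (see the reshape/permute in get_pos_ids_by_grid). This presumes the
    # image processor packs pixel patches in the same window order, so that the
    # PatchMerger's grouping of consecutive tokens lines up with 2x2 spatial neighbours.
    # Illustrative example for a 4x4 patch grid with merge size 2: the first four tokens
    # get positions (0,0), (0,1), (1,0), (1,1), one full 2x2 block, before moving on to
    # the next block.
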
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True) -> torch.Tensor:
        # Note: despite the parameter name, this MPS-oriented build casts to float16.
        if bf16:
            hidden_states = hidden_states.to(torch.float16)
        hidden_states = self.patch_embed(hidden_states, grid_thw)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Per-image sequence length is h * w patches, repeated t times; prepend 0 to get
        # the cumulative boundaries expected by the attention implementations.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb,
                    use_reentrant=(self.config.ckpt_use_reentrant or self.config.ve_ckpt_use_reentrant),
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        if self.config.post_norm:
            hidden_states = self.post_trunk_norm(hidden_states)
        hidden_states = self.merger(hidden_states)
        return hidden_states
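

# Minimal usage sketch (illustrative only; the config values and sizes below are
# assumptions for the example, not values tied to any particular checkpoint):
#
#     config = DotsVisionConfig()  # or DotsVisionConfig.from_pretrained(...)
#     model = DotsVisionTransformer(config).to("mps", torch.float16).eval()
#
#     # grid_thw: one (t, h, w) row per image; pixel_values: packed patches that can be
#     # viewed as (-1, num_channels, temporal_patch_size, patch_size, patch_size).
#     grid_thw = torch.tensor([[1, 32, 32]])
#     pixel_values = torch.randn(
#         int(grid_thw.prod(dim=-1).sum()),
#         config.num_channels * config.temporal_patch_size * config.patch_size**2,
#     )
#     features = model(pixel_values, grid_thw)
#     # features: (32 * 32 // spatial_merge_size**2, config.hidden_size) merged vision tokens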