  1. """ Vision Transformer (ViT) in PyTorch
  2. A PyTorch implement of Vision Transformers as described in
  3. 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
  4. The official jax code is released and available at https://github.com/google-research/vision_transformer
  5. Status/TODO:
  6. * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
  7. * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
  8. * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
  9. * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
  10. Acknowledgments:
  11. * The paper authors for releasing code and weights, thanks!
  12. * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  13. for some einops/einsum fun
  14. * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  15. * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
  16. Hacked together by / Copyright 2020 Ross Wightman
  17. """
import warnings
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # the dropout above is kept commented out to match the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token to cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)

            # trunc_normal_(self.relative_position_bias_table, std=.0)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
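
    # Note (illustrative, not part of the original code): for a 14x14 pre-training window the
    # bias table holds (2*14 - 1) * (2*14 - 1) + 3 = 732 entries; the last three rows are the
    # shared cls-to-token, token-to-cls and cls-to-cls biases referenced by the index buffer above.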

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            # compare as plain tuples: an elementwise tensor/tuple comparison is ambiguous in a bool context
            if tuple(training_window_size.tolist()) == tuple(self.window_size):
                relative_position_bias = \
                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        self.window_size[0] * self.window_size[1] + 1,
                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)
            else:
                training_window_size = tuple(training_window_size.tolist())
                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
                # new_num_relative_distance covers all possible relative positions, including cls-cls, tok-cls and cls-tok
                new_relative_position_bias_table = F.interpolate(
                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
                                                                                 2 * self.window_size[0] - 1,
                                                                                 2 * self.window_size[1] - 1),
                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
                    align_corners=False)
                new_relative_position_bias_table = new_relative_position_bias_table.view(
                    self.num_heads, new_num_relative_distance - 3).permute(1, 0)
                new_relative_position_bias_table = torch.cat(
                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)

                # get pair-wise relative position index for each token inside the window
                coords_h = torch.arange(training_window_size[0])
                coords_w = torch.arange(training_window_size[1])
                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
                relative_coords[:, :, 1] += training_window_size[1] - 1
                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
                relative_position_index = \
                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
                                dtype=relative_coords.dtype)
                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
                relative_position_index[0, 0:] = new_num_relative_distance - 3
                relative_position_index[0:, 0] = new_num_relative_distance - 2
                relative_position_index[0, 0] = new_num_relative_distance - 1

                relative_position_bias = \
                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
                        training_window_size[0] * training_window_size[1] + 1,
                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
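

# Usage sketch (illustrative, not part of the original code): the attention layer takes tokens
# of shape (B, N, C) with N = Wh*Ww + 1 (a cls token prepended) and, when built with a
# window_size, a 2-element `training_window_size` tensor giving the current patch grid:
#
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True, window_size=(14, 14))
#   tokens = torch.rand(2, 14 * 14 + 1, 768)
#   out = attn(tokens, training_window_size=torch.tensor([14, 14]))  # -> (2, 197, 768)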


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(
                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
                                                            training_window_size=training_window_size))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
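

# Note (illustrative): Block uses the pre-norm residual form, x + DropPath(Attn(LN(x)))
# followed by x + DropPath(MLP(LN(x))); when init_values is given, each residual branch is
# additionally scaled by the learnable per-channel gamma_1 / gamma_2 (layer scale).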


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches_w = self.patch_shape[0]
        self.num_patches_h = self.patch_shape[1]
        # the so-called patch_shape is the patch shape during pre-training
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, position_embedding=None, **kwargs):
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]

        if position_embedding is not None:
            # interpolate the position embedding to the corresponding size
            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
                0, 3, 1, 2)
            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
            x = x + position_embedding

        x = x.flatten(2).transpose(1, 2)
        return x, (Hp, Wp)
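

# Shape sketch (illustrative): with a 224x224 input and 16x16 patches the projection yields
# a 14x14 token grid, i.e.
#
#   patch_embed = PatchEmbed(img_size=[224, 224], patch_size=16, embed_dim=768)
#   tokens, (Hp, Wp) = patch_embed(torch.rand(2, 3, 224, 224))  # tokens: (2, 196, 768), Hp = Wp = 14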


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token to cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, training_window_size):
        # compare as plain tuples: an elementwise tensor/tuple comparison is ambiguous in a bool context
        if tuple(training_window_size.tolist()) == tuple(self.window_size):
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        else:
            training_window_size = tuple(training_window_size.tolist())
            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
            # new_num_relative_distance covers all possible relative positions, including cls-cls, tok-cls and cls-tok
            new_relative_position_bias_table = F.interpolate(
                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
                                                                             2 * self.window_size[0] - 1,
                                                                             2 * self.window_size[1] - 1),
                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
                align_corners=False)
            new_relative_position_bias_table = new_relative_position_bias_table.view(
                self.num_heads, new_num_relative_distance - 3).permute(1, 0)
            new_relative_position_bias_table = torch.cat(
                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(training_window_size[0])
            coords_w = torch.arange(training_window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += training_window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
                            dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = new_num_relative_distance - 3
            relative_position_index[0:, 0] = new_num_relative_distance - 2
            relative_position_index[0, 0] = new_num_relative_distance - 1

            relative_position_bias = \
                new_relative_position_bias_table[relative_position_index.view(-1)].view(
                    training_window_size[0] * training_window_size[1] + 1,
                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        return relative_position_bias
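

# Usage sketch (illustrative): the shared bias has shape (num_heads, Wh*Ww + 1, Wh*Ww + 1);
# when the requested window differs from the pre-training window, the table (minus the three
# cls-related entries) is bicubically interpolated to the new size first:
#
#   rel_pos = RelativePositionBias(window_size=(14, 14), num_heads=12)
#   bias = rel_pos(torch.tensor([14, 14]))  # -> (12, 197, 197)
#   bias = rel_pos(torch.tensor([20, 20]))  # -> (12, 401, 401)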


class BEiT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size=[224, 224],
                 patch_size=16,
                 in_chans=3,
                 num_classes=80,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=None,
                 init_values=None,
                 use_abs_pos_emb=False,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 use_checkpoint=True,
                 pretrained=None,
                 out_features=None,
                 ):
        super(BEiT, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.use_checkpoint = use_checkpoint

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.out_features = out_features
        self.out_indices = [int(name[5:]) for name in out_features]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])

        # trunc_normal_(self.mask_token, std=.02)

        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                # nn.SyncBatchNorm(embed_dim),
                nn.BatchNorm2d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn3 = nn.Identity()

            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Identity()

            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    '''
    def init_weights(self):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        logger = get_root_logger()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
            load_checkpoint(self,
                            filename=self.init_cfg['checkpoint'],
                            strict=False,
                            logger=logger,
                            beit_spec_expand_rel_pos=self.use_rel_pos_bias,
                            )
    '''

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
        # Hp, Wp are HW for patches
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.pos_embed is not None:
            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        features = []
        training_window_size = torch.tensor([Hp, Wp])

        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None

        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
            if i in self.out_indices:
                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        feat_out = {}
        for name, value in zip(self.out_features, features):
            feat_out[name] = value

        return feat_out

    def forward(self, x):
        x = self.forward_features(x)
        return x
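
    # Note (illustrative): for patch_size=16 the four FPN branches return feature maps at
    # strides 4, 8, 16 and 32 relative to the input (fpn1 upsamples 4x, fpn2 upsamples 2x,
    # fpn3 is identity, fpn4 downsamples 2x), keyed by the names passed in `out_features`.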


def beit_base_patch16(pretrained=False, **kwargs):
    model = BEiT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=None,
        **kwargs)
    model.default_cfg = _cfg()
    return model


def beit_large_patch16(pretrained=False, **kwargs):
    model = BEiT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=None,
        **kwargs)
    model.default_cfg = _cfg()
    return model


def dit_base_patch16(pretrained=False, **kwargs):
    model = BEiT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.1,
        **kwargs)
    model.default_cfg = _cfg()
    return model


def dit_large_patch16(pretrained=False, **kwargs):
    model = BEiT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=1e-5,
        **kwargs)
    model.default_cfg = _cfg()
    return model


if __name__ == '__main__':
    # out_features is required by BEiT.__init__ (it is indexed directly); the "layer<idx>"
    # names below follow the convention implied by `int(name[5:])` and are illustrative.
    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True,
                 out_features=["layer3", "layer5", "layer7", "layer11"])
    model = model.to("cuda:0")
    input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
    input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
    input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
    output1 = model(input1)
    output2 = model(input2)
    output3 = model(input3)
    print("all done")