# deit.py

  1. """
  2. Mostly copy-paste from DINO and timm library:
  3. https://github.com/facebookresearch/dino
  4. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  5. """
  6. import warnings
  7. import math
  8. import torch
  9. import torch.nn as nn
  10. import torch.utils.checkpoint as checkpoint
  11. from timm.models.layers import trunc_normal_, drop_path, to_2tuple
  12. from functools import partial
  13. def _cfg(url='', **kwargs):
  14. return {
  15. 'url': url,
  16. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  17. 'crop_pct': .9, 'interpolation': 'bicubic',
  18. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  19. **kwargs
  20. }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # qkv(x): (B, N, 3C) -> (B, N, 3, heads, C // heads) -> (3, B, heads, N, C // heads)
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                      C // self.num_heads).permute(2, 0, 3, 1, 4)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches_w, self.num_patches_h = self.window_size
        self.num_patches = self.window_size[0] * self.window_size[1]

        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        return x


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(
                    1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
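

# Illustrative usage of HybridEmbed (an assumption, not part of the original file):
# the wrapped backbone must return a list of feature maps and expose `feature_info`,
# which is what a timm features-only model provides, e.g.
#
#     import timm
#     cnn = timm.create_model('resnet50', features_only=True, pretrained=False)
#     hybrid = HybridEmbed(cnn, img_size=224, embed_dim=768)
#     tokens = hybrid(torch.randn(1, 3, 224, 224))  # (1, num_patches, 768)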


class ViT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 model_name='vit_base_patch16_224',
                 img_size=384,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=1024,
                 depth=24,
                 num_heads=16,
                 num_classes=19,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.1,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_cfg=None,
                 pos_embed_interp=False,
                 random_init=False,
                 align_corners=False,
                 use_checkpoint=False,
                 num_extra_tokens=1,
                 out_features=None,
                 **kwargs,
                 ):
        super(ViT, self).__init__()
        self.model_name = model_name
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.hybrid_backbone = hybrid_backbone
        self.norm_layer = norm_layer
        self.norm_cfg = norm_cfg
        self.pos_embed_interp = pos_embed_interp
        self.random_init = random_init
        self.align_corners = align_corners
        self.use_checkpoint = use_checkpoint
        self.num_extra_tokens = num_extra_tokens
        self.out_features = out_features
        # each out_features name is expected to encode its block index after a
        # 5-character prefix, e.g. 'layer3' -> block 3
        self.out_indices = [int(name[5:]) for name in out_features]
        # self.num_stages = self.depth
        # self.out_indices = tuple(range(self.num_stages))
        if self.hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        if self.num_extra_tokens == 2:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
        self.pos_drop = nn.Dropout(p=self.drop_rate)

        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
                                                self.depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
                qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
            for i in range(self.depth)])

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn3 = nn.Identity()

            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Identity()

            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if self.num_extra_tokens == 2:
            trunc_normal_(self.dist_token, std=.02)
        self.apply(self._init_weights)
        # self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    '''
    def init_weights(self):
        logger = get_root_logger()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
    '''

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _conv_filter(self, state_dict, patch_size=16):
        """ convert patch embedding weight from manual patchify + linear proj to conv"""
        out_dict = {}
        for k, v in state_dict.items():
            if 'patch_embed.proj.weight' in k:
                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
            out_dict[k] = v
        return out_dict

    def to_2D(self, x):
        n, hw, c = x.shape
        h = w = int(math.sqrt(hw))
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def to_1D(self, x):
        n, c, h, w = x.shape
        x = x.reshape(n, c, -1).transpose(1, 2)
        return x

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - self.num_extra_tokens
        N = self.pos_embed.shape[1] - self.num_extra_tokens
        if npatch == N and w == h:
            return self.pos_embed

        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]

        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]

        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)

    def prepare_tokens(self, x, mask=None):
        B, nc, w, h = x.shape
        # patch linear embedding
        x = self.patch_embed(x)

        # mask image modeling
        if mask is not None:
            x = self.mask_model(x, mask)
        x = x.flatten(2).transpose(1, 2)

        # add the [CLS] token to the embed patch tokens
        all_tokens = [self.cls_token.expand(B, -1, -1)]

        if self.num_extra_tokens == 2:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            all_tokens.append(dist_tokens)
        all_tokens.append(x)

        x = torch.cat(all_tokens, dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def forward_features(self, x):
        # print(f"==========shape of x is {x.shape}==========")
        B, _, H, W = x.shape
        Hp, Wp = H // self.patch_size, W // self.patch_size
        x = self.prepare_tokens(x)

        features = []
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            if i in self.out_indices:
                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        feat_out = {}

        for name, value in zip(self.out_features, features):
            feat_out[name] = value

        return feat_out

    def forward(self, x):
        x = self.forward_features(x)
        return x


def deit_base_patch16(pretrained=False, **kwargs):
    model = ViT(
        patch_size=16,
        drop_rate=0.,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=1000,
        mlp_ratio=4.,
        qkv_bias=True,
        use_checkpoint=True,
        num_extra_tokens=2,
        **kwargs)
    model.default_cfg = _cfg()
    return model


def mae_base_patch16(pretrained=False, **kwargs):
    model = ViT(
        patch_size=16,
        drop_rate=0.,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=1000,
        mlp_ratio=4.,
        qkv_bias=True,
        use_checkpoint=True,
        num_extra_tokens=1,
        **kwargs)
    model.default_cfg = _cfg()
    return model
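

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test under the assumption that out_features names follow the
# 'layer<idx>' convention parsed in ViT.__init__; the image size and layer
# names below are arbitrary choices for illustration.
if __name__ == '__main__':
    model = mae_base_patch16(
        img_size=224,
        out_features=['layer3', 'layer5', 'layer7', 'layer11'],
    )
    # NOTE: fpn1 uses nn.SyncBatchNorm, which generally expects CUDA tensors
    # (and a distributed setup for training); run on GPU where possible.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224, device=device))
    for name, feat in feats.items():
        # four pyramid levels at strides 4, 8, 16 and 32 relative to the input
        print(name, tuple(feat.shape))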