| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- """
- Mostly copy-paste from DINO and timm library:
- https://github.com/facebookresearch/dino
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
- """
- import warnings
- import math
- import torch
- import torch.nn as nn
- import torch.utils.checkpoint as checkpoint
- from timm.models.layers import trunc_normal_, drop_path, to_2tuple
- from functools import partial
def _cfg(url='', **kwargs):
    """Build a default model configuration dictionary.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries. They are applied after the defaults, so a
            keyword that matches a default key replaces that default.

    Returns:
        A new dict holding the defaults below merged with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    )
    # Caller-supplied entries win over the defaults.
    cfg.update(kwargs)
    return cfg
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper that forwards to the imported ``drop_path`` function
    with the stored probability and the module's current training flag.

    Args:
        drop_prob: Drop probability handed to ``drop_path``; stored as given.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def extra_repr(self) -> str:
        """Return the ``p=<drop_prob>`` fragment shown in the module repr."""
        return f'p={self.drop_prob}'

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using this module's training state."""
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: Input size of the first linear layer.
        hidden_features: Width of the hidden layer; ``in_features`` is used
            when this is falsy (``None`` or ``0``).
        out_features: Output size of the second linear layer; ``in_features``
            is used when this is falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Probability for the single ``nn.Dropout`` instance that is
            applied after the activation and again after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both linear layers with activation and dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over an input of shape ``(B, N, C)``.

    Args:
        dim: Channel size ``C`` of the input and output.
        num_heads: Number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        qk_scale: Scale applied to the q·k logits. When falsy, the scale is
            ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE: the scale can be set manually through qk_scale to stay
        # compatible with weights trained under a different scale factor.
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        heads = self.num_heads
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        # After the permute the leading axis indexes q / k / v; each slice is
        # (batch, heads, tokens, channels // heads).
        q, k, v = fused.permute(2, 0, 3, 1, 4)
        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Transformer block: normalised attention and MLP branches, each added residually.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated with ``int``).
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Dropout probability used for the attention output projection
            and inside the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic-depth probability; when not greater than zero
            an ``nn.Identity`` is used instead of ``DropPath``.
        act_layer: Activation factory forwarded to ``Mlp``.
        norm_layer: Callable ``norm_layer(dim)`` building each norm module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention branch, then the MLP branch, both residually."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
class PatchEmbed(nn.Module):
    """Image to patch embedding using a convolution whose kernel equals its stride.

    Args:
        img_size: Image size; expanded with ``to_2tuple``.
        patch_size: Patch size; expanded with ``to_2tuple`` and used as both
            kernel size and stride of the projection.
        in_chans: Number of input channels of the projection.
        embed_dim: Number of output channels of the projection.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # Number of whole patches that fit along each of the two image axes.
        grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.window_size = grid
        self.num_patches_w, self.num_patches_h = grid
        self.num_patches = grid[0] * grid[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return the convolutional projection of ``x`` (not flattened here)."""
        return self.proj(x)
class HybridEmbed(nn.Module):
    """CNN feature map embedding.

    Takes the last output of ``backbone``, flattens its spatial axes, and
    projects it to ``embed_dim`` with a linear layer.

    Args:
        backbone: ``nn.Module`` whose call result is indexed with ``[-1]``.
        img_size: Input size used for the probing forward pass; expanded
            with ``to_2tuple``.
        feature_size: Spatial size of the backbone's last feature map. When
            ``None`` it is measured by running the backbone on zeros.
        in_chans: Channel count of the probing input.
        embed_dim: Output size of the linear projection.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        else:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                dummy = torch.zeros(1, in_chans, img_size[0], img_size[1])
                probe = self.backbone(dummy)[-1]
                feature_size = probe.shape[-2:]
                feature_dim = probe.shape[1]
                # Put the backbone back into the mode it was in before probing.
                backbone.train(was_training)
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        """Return the projected, flattened last backbone feature map."""
        feature_map = self.backbone(x)[-1]
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.proj(tokens)
class ViT(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    The forward pass returns a dict of 2-D feature maps taken after selected
    transformer blocks, each passed through one of four resampling heads
    (``fpn1`` .. ``fpn4``).

    Args:
        model_name: Stored on the instance; not otherwise used in this class.
        img_size: Image size forwarded to the patch embedding.
        patch_size: Patch size. Only ``16`` and ``8`` create ``fpn1``..``fpn4``;
            ``forward_features`` also divides by it, so it must be an int.
        in_chans: Number of input channels.
        embed_dim: Token channel size.
        depth: Number of transformer blocks.
        num_heads: Number of attention heads per block.
        num_classes: Stored on the instance; not otherwise used in this class.
        mlp_ratio: MLP hidden width multiplier per block.
        qkv_bias: Whether q/k/v projections have a bias.
        qk_scale: Optional manual attention scale.
        drop_rate: Dropout probability after position embedding and in blocks.
        attn_drop_rate: Dropout probability on attention weights.
        drop_path_rate: Maximum stochastic-depth rate; block ``i`` receives
            the ``i``-th value of a linear ramp from 0 to this rate.
        hybrid_backbone: When not ``None``, ``HybridEmbed`` is used instead of
            ``PatchEmbed``.
        norm_layer: Factory for the normalisation layers in the blocks.
        norm_cfg, pos_embed_interp, random_init, align_corners: Stored on the
            instance; not otherwise used in this class.
        use_checkpoint: Run each block through ``torch.utils.checkpoint``.
        num_extra_tokens: Number of non-patch tokens; ``2`` adds a
            ``dist_token`` next to ``cls_token``.
        out_features: Iterable of output names. The characters after the
            first five of each name are parsed as the block index whose
            output is exported (presumably names like ``'layer3'`` — TODO
            confirm the naming used by callers). Required.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``out_features`` is ``None``.
    """

    def __init__(self,
                 model_name='vit_base_patch16_224',
                 img_size=384,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=1024,
                 depth=24,
                 num_heads=16,
                 num_classes=19,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.1,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_cfg=None,
                 pos_embed_interp=False,
                 random_init=False,
                 align_corners=False,
                 use_checkpoint=False,
                 num_extra_tokens=1,
                 out_features=None,
                 **kwargs,
                 ):
        super(ViT, self).__init__()
        # Fail with a clear message instead of the opaque
        # "'NoneType' object is not iterable" raised by the comprehension below.
        if out_features is None:
            raise ValueError(
                "out_features must be an iterable of output names whose "
                "characters after the first five give the block index; got None")
        self.model_name = model_name
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.hybrid_backbone = hybrid_backbone
        self.norm_layer = norm_layer
        self.norm_cfg = norm_cfg
        self.pos_embed_interp = pos_embed_interp
        self.random_init = random_init
        self.align_corners = align_corners
        self.use_checkpoint = use_checkpoint
        self.num_extra_tokens = num_extra_tokens
        self.out_features = out_features
        # Block index = integer formed by everything after the 5-char prefix.
        self.out_indices = [int(name[5:]) for name in out_features]
        # self.num_stages = self.depth
        # self.out_indices = tuple(range(self.num_stages))

        if self.hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        if self.num_extra_tokens == 2:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
        self.pos_drop = nn.Dropout(p=self.drop_rate)
        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches

        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
                                                self.depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
                qk_scale=self.qk_scale,
                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
            for i in range(self.depth)])

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        # Resampling heads applied, in order, to the exported feature maps.
        # NOTE(review): for any patch_size other than 16 or 8 no heads are
        # created, so forward_features will fail on self.fpn1.
        if patch_size == 16:
            # 4x upsample, 2x upsample, identity, 2x downsample.
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn3 = nn.Identity()
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            # 2x upsample, identity, 2x downsample, 4x downsample.
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )
            self.fpn2 = nn.Identity()
            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if self.num_extra_tokens == 2:
            # Same std as cls_token / pos_embed; the previous 0.2 was ten
            # times larger than every other token initialisation here.
            trunc_normal_(self.dist_token, std=.02)
        self.apply(self._init_weights)
        # self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's attention-proj and MLP-fc2 weights by sqrt(2 * (block index + 1))."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``.

        Linear weights get ``trunc_normal_(std=.02)`` and zero bias; LayerNorm
        gets zero bias and unit weight. Other module types are left as is.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    '''
    def init_weights(self):
        logger = get_root_logger()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
    '''

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def _conv_filter(self, state_dict, patch_size=16):
        """ convert patch embedding weight from manual patchify + linear proj to conv"""
        out_dict = {}
        for k, v in state_dict.items():
            if 'patch_embed.proj.weight' in k:
                # Reshaped to (out_channels, 3, patch_size, patch_size); the
                # input-channel count is fixed at 3 here.
                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
            out_dict[k] = v
        return out_dict

    def to_2D(self, x):
        """Reshape ``(n, hw, c)`` tokens into an ``(n, c, h, w)`` map with ``h == w == int(sqrt(hw))``."""
        n, hw, c = x.shape
        h = w = int(math.sqrt(hw))
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def to_1D(self, x):
        """Flatten an ``(n, c, h, w)`` map into ``(n, h*w, c)`` tokens."""
        n, c, h, w = x.shape
        x = x.reshape(n, c, -1).transpose(1, 2)
        return x

    def interpolate_pos_encoding(self, x, w, h):
        """Return position embeddings matching the token count of ``x``.

        Args:
            x: Token tensor whose second axis holds extra tokens followed by
                patch tokens.
            w, h: The third and fourth axis sizes of the input image tensor,
                as unpacked by ``prepare_tokens``.

        Returns:
            ``self.pos_embed`` unchanged when the patch count matches and
            ``w == h``; otherwise the extra-token embeddings concatenated
            with a bicubically resized copy of the patch embeddings.
        """
        npatch = x.shape[1] - self.num_extra_tokens
        N = self.pos_embed.shape[1] - self.num_extra_tokens
        if npatch == N and w == h:
            return self.pos_embed
        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        # The stored patch embeddings are treated as a square sqrt(N) x sqrt(N) grid.
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)

    def prepare_tokens(self, x, mask=None):
        """Embed an image batch into tokens with extra tokens and position encoding.

        Args:
            x: Image batch with four axes.
            mask: Optional mask passed to ``self.mask_model``.
                NOTE(review): ``mask_model`` is not defined in this class;
                presumably it is supplied by a subclass — confirm before
                passing a mask.

        Returns:
            Tokens after position-embedding addition and dropout.
        """
        B, nc, w, h = x.shape
        # patch linear embedding
        x = self.patch_embed(x)

        # mask image modeling
        if mask is not None:
            x = self.mask_model(x, mask)
        x = x.flatten(2).transpose(1, 2)

        # add the [CLS] token to the embed patch tokens
        all_tokens = [self.cls_token.expand(B, -1, -1)]

        if self.num_extra_tokens == 2:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            all_tokens.append(dist_tokens)
        all_tokens.append(x)

        x = torch.cat(all_tokens, dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward_features(self, x):
        """Run the transformer and return the resampled intermediate feature maps.

        Returns:
            Dict mapping each name in ``self.out_features`` to a feature map.
            Maps are collected in block order and paired with names by
            position, so ``out_features`` is assumed to list blocks in
            ascending index order — TODO confirm with callers.

        Raises:
            IndexError: If more than four block outputs are exported, since
                only four resampling heads exist.
        """
        B, _, H, W = x.shape
        Hp, Wp = H // self.patch_size, W // self.patch_size
        x = self.prepare_tokens(x)

        features = []
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            if i in self.out_indices:
                # Drop the extra tokens and fold the patch tokens back to 2-D.
                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        # Index lookup (rather than zip) keeps the IndexError when there are
        # more exported maps than heads.
        features = [ops[i](feat) for i, feat in enumerate(features)]

        return dict(zip(self.out_features, features))

    def forward(self, x):
        """Return ``forward_features(x)``."""
        return self.forward_features(x)
def deit_base_patch16(pretrained=False, **kwargs):
    """Build a base-size, patch-16 ``ViT`` with two extra tokens.

    Args:
        pretrained: Accepted but not used by this function.
        **kwargs: Additional keyword arguments for ``ViT``. Repeating one of
            the fixed settings below raises ``TypeError``.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    fixed_settings = dict(
        patch_size=16,
        drop_rate=0.,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=1000,
        mlp_ratio=4.,
        qkv_bias=True,
        use_checkpoint=True,
        num_extra_tokens=2,
    )
    net = ViT(**fixed_settings, **kwargs)
    net.default_cfg = _cfg()
    return net
def mae_base_patch16(pretrained=False, **kwargs):
    """Build a base-size, patch-16 ``ViT`` with a single extra token.

    Args:
        pretrained: Accepted but not used by this function.
        **kwargs: Additional keyword arguments for ``ViT``. Repeating one of
            the fixed settings below raises ``TypeError``.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    fixed_settings = dict(
        patch_size=16,
        drop_rate=0.,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=1000,
        mlp_ratio=4.,
        qkv_bias=True,
        use_checkpoint=True,
        num_extra_tokens=1,
    )
    net = ViT(**fixed_settings, **kwargs)
    net.default_cfg = _cfg()
    return net
|