vision_transformer.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant

__all__ = [
    "VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224",
    "ViT_base_patch16_384", "ViT_base_patch32_384", "ViT_large_patch16_224",
    "ViT_large_patch16_384", "ViT_large_patch32_384", "ViT_huge_patch16_224",
    "ViT_huge_patch32_384"
]

trunc_normal_ = TruncatedNormal(std=.02)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


def to_2tuple(x):
    return tuple([x] * 2)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
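

# Note on DropPath: during training it zeroes the entire residual branch for a
# random subset of samples in the batch (each with probability drop_prob) and
# rescales the surviving samples by 1 / keep_prob, so the branch output keeps
# its expected value. With drop_prob == 0, or when training is False, it is a
# plain identity op.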


class Identity(nn.Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # B = paddle.shape(x)[0]
        N, C = x.shape[1:]
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
                                   self.num_heads)).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
        attn = nn.functional.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
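

# Shape walkthrough for Attention.forward, taking ViT-Base/16 at 224x224 as a
# concrete example (C = 768, num_heads = 12, head_dim = 64, and N = 197 tokens,
# i.e. 196 patches plus one class token):
#   input x: (B, 197, 768)
#   qkv:     (3, B, 12, 197, 64) after the reshape/transpose
#   attn:    (B, 12, 197, 197) = softmax(q @ k^T / sqrt(64)) per head
#   output:  (B, 197, 768) after merging heads and the final projection.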


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5):
        super().__init__()
        self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
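

# Block is the standard pre-norm transformer layer: each sub-layer (multi-head
# attention, then the MLP) is applied to a LayerNorm-ed input and added back to
# the residual stream, with optional stochastic depth on both residual branches.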


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * \
            (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2D(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose((0, 2, 1))
        return x
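

# Patch-count arithmetic: with the default img_size=224 and patch_size=16 the
# strided Conv2D yields a 14x14 grid, i.e. num_patches = (224 // 16) ** 2 = 196
# tokens, each projected to embed_dim channels, so forward returns a tensor of
# shape (B, 196, embed_dim).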


class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for patch input
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 class_dim=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 **args):
        super().__init__()
        self.class_dim = class_dim
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = self.create_parameter(
            shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)
        self.cls_token = self.create_parameter(
            shape=(1, 1, embed_dim), default_initializer=zeros_)
        self.add_parameter("cls_token", self.cls_token)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = np.linspace(0, drop_path_rate, depth)

        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                epsilon=epsilon) for i in range(depth)
        ])

        self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)

        # Classifier head
        self.head = nn.Linear(embed_dim,
                              class_dim) if class_dim > 0 else Identity()

        trunc_normal_(self.pos_embed)
        trunc_normal_(self.cls_token)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        # B = x.shape[0]
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand((B, -1, -1))
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
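

# End-to-end flow of VisionTransformer.forward: the image is split into patch
# embeddings, a learnable class token is prepended, learnable position
# embeddings are added, the sequence passes through `depth` transformer blocks
# (with stochastic depth rates linearly spaced from 0 to drop_path_rate), and
# the final LayerNorm-ed class token is fed to the linear classifier head.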


def ViT_small_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=8,
        num_heads=8,
        mlp_ratio=3,
        qk_scale=768**-0.5,
        **kwargs)
    return model


def ViT_base_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_base_patch16_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_base_patch32_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_large_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_large_patch16_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_large_patch32_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def ViT_huge_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        **kwargs)
    return model


def ViT_huge_patch32_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=32,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        **kwargs)
    return model
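

# A minimal smoke-test sketch showing how the factory functions above are
# used: build ViT-Base/16 at 224x224 and run one forward pass on a random
# batch. It assumes a working paddle install; class_dim and the input
# resolution simply follow the defaults defined in this file.
if __name__ == "__main__":
    model = ViT_base_patch16_224(class_dim=1000)
    model.eval()
    # Random image batch: (batch, channels, height, width).
    images = paddle.randn([2, 3, 224, 224])
    with paddle.no_grad():
        logits = model(images)
    # Expected output shape: [2, 1000] (one score per class).
    print(logits.shape)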