
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from .vision_transformer import VisionTransformer, Identity, trunc_normal_, zeros_

__all__ = [
    'DeiT_tiny_patch16_224', 'DeiT_small_patch16_224', 'DeiT_base_patch16_224',
    'DeiT_tiny_distilled_patch16_224', 'DeiT_small_distilled_patch16_224',
    'DeiT_base_distilled_patch16_224', 'DeiT_base_patch16_384',
    'DeiT_base_distilled_patch16_384'
]

class DistilledVisionTransformer(VisionTransformer):
    """DeiT backbone: a VisionTransformer extended with a learnable
    distillation token and a second classification head, following
    Touvron et al., "Training data-efficient image transformers &
    distillation through attention"."""

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 class_dim=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 qkv_bias=False,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-5,
                 **kwargs):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            class_dim=class_dim,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            epsilon=epsilon,
            **kwargs)

        # Re-create the position embedding with room for two extra tokens
        # (class token + distillation token) instead of the usual one.
        self.pos_embed = self.create_parameter(
            shape=(1, self.patch_embed.num_patches + 2, self.embed_dim),
            default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)

        # The distillation token is a learnable embedding, analogous to the
        # class token inherited from VisionTransformer.
        self.dist_token = self.create_parameter(
            shape=(1, 1, self.embed_dim), default_initializer=zeros_)
        self.add_parameter("dist_token", self.dist_token)

        # Second classification head, trained against the teacher's outputs.
        self.head_dist = nn.Linear(
            self.embed_dim,
            self.class_dim) if self.class_dim > 0 else Identity()

        trunc_normal_(self.dist_token)
        trunc_normal_(self.pos_embed)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand((B, -1, -1))
        dist_token = self.dist_token.expand((B, -1, -1))
        # Prepend the class and distillation tokens to the patch sequence.
        x = paddle.concat((cls_tokens, dist_token, x), axis=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Return the final class-token and distillation-token embeddings.
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        # At inference time, average the logits of the two heads.
        return (x + x_dist) / 2


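# A minimal shape sketch (not part of the original file) of the token layout
# the class above produces, assuming the DeiT-tiny configuration used below
# (img_size=224, patch_size=16, embed_dim=192). num_patches = (224 / 16)**2
# = 196, so the concatenated sequence has 196 + 2 = 198 tokens, matching the
# pos_embed shape of num_patches + 2:
#
#     x = paddle.randn((8, 3, 224, 224))             # batch of 8 images
#     x = model.patch_embed(x)                       # (8, 196, 192)
#     x = paddle.concat((cls, dist, x), axis=1)      # (8, 198, 192)
#     feat_cls, feat_dist = x[:, 0], x[:, 1]         # (8, 192) each

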
def DeiT_tiny_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_small_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_base_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


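# Usage sketch (illustrative, not part of the original file): each factory
# forwards **kwargs to the underlying constructor, so per-task settings such
# as the classifier width can be overridden at construction time:
#
#     model = DeiT_small_patch16_224(class_dim=10)   # 10-way classifier head

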
def DeiT_tiny_distilled_patch16_224(**kwargs):
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_small_distilled_patch16_224(**kwargs):
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_base_distilled_patch16_224(**kwargs):
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_base_patch16_384(**kwargs):
    model = VisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model


def DeiT_base_distilled_patch16_384(**kwargs):
    model = DistilledVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        epsilon=1e-6,
        **kwargs)
    return model
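

# End-to-end usage sketch (illustrative, not part of the original file).
# Because of the relative import at the top, this module must be imported
# from its enclosing package rather than run as a standalone script; the
# batch size below is an assumption for illustration:
#
#     model = DeiT_tiny_distilled_patch16_224()
#     images = paddle.randn((2, 3, 224, 224))   # batch of 2 RGB images
#     logits = model(images)                    # averaged cls/dist logits
#     print(logits.shape)                       # [2, 1000]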