swin_transformer.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import TruncatedNormal, Constant, Assign
  18. from paddlex.ppdet.modeling.shape_spec import ShapeSpec
  19. from paddlex.ppdet.core.workspace import register, serializable
  20. import numpy as np
# Common initializations
# Shared initializer objects. Elsewhere in this file they are applied by
# calling them directly on a parameter, e.g. ``zeros_(m.bias)``.
ones_ = Constant(value=1.)  # fills with 1 (used for LayerNorm weights)
zeros_ = Constant(value=0.)  # fills with 0 (used for biases)
trunc_normal_ = TruncatedNormal(std=.02)  # truncated normal, std 0.02
  25. # Common Functions
  26. def to_2tuple(x):
  27. return tuple([x] * 2)
  28. def add_parameter(layer, datas, name=None):
  29. parameter = layer.create_parameter(
  30. shape=(datas.shape), default_initializer=Assign(datas))
  31. if name:
  32. layer.add_parameter(name, parameter)
  33. return parameter
  34. # Common Layers
  35. def drop_path(x, drop_prob=0., training=False):
  36. """
  37. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  38. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  39. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
  40. """
  41. if drop_prob == 0. or not training:
  42. return x
  43. keep_prob = paddle.to_tensor(1 - drop_prob)
  44. shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
  45. random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
  46. random_tensor = paddle.floor(random_tensor) # binarize
  47. output = x.divide(keep_prob) * random_tensor
  48. return output
  49. class DropPath(nn.Layer):
  50. def __init__(self, drop_prob=None):
  51. super(DropPath, self).__init__()
  52. self.drop_prob = drop_prob
  53. def forward(self, x):
  54. return drop_path(x, self.drop_prob, self.training)
  55. class Identity(nn.Layer):
  56. def __init__(self):
  57. super(Identity, self).__init__()
  58. def forward(self, input):
  59. return input
  60. class Mlp(nn.Layer):
  61. def __init__(self,
  62. in_features,
  63. hidden_features=None,
  64. out_features=None,
  65. act_layer=nn.GELU,
  66. drop=0.):
  67. super().__init__()
  68. out_features = out_features or in_features
  69. hidden_features = hidden_features or in_features
  70. self.fc1 = nn.Linear(in_features, hidden_features)
  71. self.act = act_layer()
  72. self.fc2 = nn.Linear(hidden_features, out_features)
  73. self.drop = nn.Dropout(drop)
  74. def forward(self, x):
  75. x = self.fc1(x)
  76. x = self.act(x)
  77. x = self.drop(x)
  78. x = self.fc2(x)
  79. x = self.drop(x)
  80. return x
  81. def window_partition(x, window_size):
  82. """
  83. Args:
  84. x: (B, H, W, C)
  85. window_size (int): window size
  86. Returns:
  87. windows: (num_windows*B, window_size, window_size, C)
  88. """
  89. B, H, W, C = x.shape
  90. x = x.reshape(
  91. [B, H // window_size, window_size, W // window_size, window_size, C])
  92. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
  93. [-1, window_size, window_size, C])
  94. return windows
  95. def window_reverse(windows, window_size, H, W):
  96. """
  97. Args:
  98. windows: (num_windows*B, window_size, window_size, C)
  99. window_size (int): Window size
  100. H (int): Height of image
  101. W (int): Width of image
  102. Returns:
  103. x: (B, H, W, C)
  104. """
  105. B = int(windows.shape[0] / (H * W / window_size / window_size))
  106. x = windows.reshape(
  107. [B, H // window_size, W // window_size, window_size, window_size, -1])
  108. x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
  109. return x
  110. class WindowAttention(nn.Layer):
  111. """ Window based multi-head self attention (W-MSA) module with relative position bias.
  112. It supports both of shifted and non-shifted window.
  113. Args:
  114. dim (int): Number of input channels.
  115. window_size (tuple[int]): The height and width of the window.
  116. num_heads (int): Number of attention heads.
  117. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  118. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  119. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  120. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  121. """
  122. def __init__(self,
  123. dim,
  124. window_size,
  125. num_heads,
  126. qkv_bias=True,
  127. qk_scale=None,
  128. attn_drop=0.,
  129. proj_drop=0.):
  130. super().__init__()
  131. self.dim = dim
  132. self.window_size = window_size # Wh, Ww
  133. self.num_heads = num_heads
  134. head_dim = dim // num_heads
  135. self.scale = qk_scale or head_dim**-0.5
  136. # define a parameter table of relative position bias
  137. self.relative_position_bias_table = add_parameter(
  138. self,
  139. paddle.zeros(((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
  140. num_heads))) # 2*Wh-1 * 2*Ww-1, nH
  141. # get pair-wise relative position index for each token inside the window
  142. coords_h = paddle.arange(self.window_size[0])
  143. coords_w = paddle.arange(self.window_size[1])
  144. coords = paddle.stack(paddle.meshgrid(
  145. [coords_h, coords_w])) # 2, Wh, Ww
  146. coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
  147. coords_flatten_1 = coords_flatten.unsqueeze(axis=2)
  148. coords_flatten_2 = coords_flatten.unsqueeze(axis=1)
  149. relative_coords = coords_flatten_1 - coords_flatten_2
  150. relative_coords = relative_coords.transpose(
  151. [1, 2, 0]) # Wh*Ww, Wh*Ww, 2
  152. relative_coords[:, :, 0] += self.window_size[
  153. 0] - 1 # shift to start from 0
  154. relative_coords[:, :, 1] += self.window_size[1] - 1
  155. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  156. self.relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  157. self.register_buffer("relative_position_index",
  158. self.relative_position_index)
  159. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  160. self.attn_drop = nn.Dropout(attn_drop)
  161. self.proj = nn.Linear(dim, dim)
  162. self.proj_drop = nn.Dropout(proj_drop)
  163. trunc_normal_(self.relative_position_bias_table)
  164. self.softmax = nn.Softmax(axis=-1)
  165. def forward(self, x, mask=None):
  166. """ Forward function.
  167. Args:
  168. x: input features with shape of (num_windows*B, N, C)
  169. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  170. """
  171. B_, N, C = x.shape
  172. qkv = self.qkv(x).reshape(
  173. [B_, N, 3, self.num_heads, C // self.num_heads]).transpose(
  174. [2, 0, 3, 1, 4])
  175. q, k, v = qkv[0], qkv[1], qkv[2]
  176. q = q * self.scale
  177. attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
  178. index = self.relative_position_index.reshape([-1])
  179. relative_position_bias = paddle.index_select(
  180. self.relative_position_bias_table, index)
  181. relative_position_bias = relative_position_bias.reshape([
  182. self.window_size[0] * self.window_size[1],
  183. self.window_size[0] * self.window_size[1], -1
  184. ]) # Wh*Ww,Wh*Ww,nH
  185. relative_position_bias = relative_position_bias.transpose(
  186. [2, 0, 1]) # nH, Wh*Ww, Wh*Ww
  187. attn = attn + relative_position_bias.unsqueeze(0)
  188. if mask is not None:
  189. nW = mask.shape[0]
  190. attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N
  191. ]) + mask.unsqueeze(1).unsqueeze(0)
  192. attn = attn.reshape([-1, self.num_heads, N, N])
  193. attn = self.softmax(attn)
  194. else:
  195. attn = self.softmax(attn)
  196. attn = self.attn_drop(attn)
  197. # x = (attn @ v).transpose(1, 2).reshape([B_, N, C])
  198. x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C])
  199. x = self.proj(x)
  200. x = self.proj_drop(x)
  201. return x
  202. class SwinTransformerBlock(nn.Layer):
  203. """ Swin Transformer Block.
  204. Args:
  205. dim (int): Number of input channels.
  206. num_heads (int): Number of attention heads.
  207. window_size (int): Window size.
  208. shift_size (int): Shift size for SW-MSA.
  209. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  210. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  211. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  212. drop (float, optional): Dropout rate. Default: 0.0
  213. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  214. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  215. act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
  216. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  217. """
  218. def __init__(self,
  219. dim,
  220. num_heads,
  221. window_size=7,
  222. shift_size=0,
  223. mlp_ratio=4.,
  224. qkv_bias=True,
  225. qk_scale=None,
  226. drop=0.,
  227. attn_drop=0.,
  228. drop_path=0.,
  229. act_layer=nn.GELU,
  230. norm_layer=nn.LayerNorm):
  231. super().__init__()
  232. self.dim = dim
  233. self.num_heads = num_heads
  234. self.window_size = window_size
  235. self.shift_size = shift_size
  236. self.mlp_ratio = mlp_ratio
  237. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  238. self.norm1 = norm_layer(dim)
  239. self.attn = WindowAttention(
  240. dim,
  241. window_size=to_2tuple(self.window_size),
  242. num_heads=num_heads,
  243. qkv_bias=qkv_bias,
  244. qk_scale=qk_scale,
  245. attn_drop=attn_drop,
  246. proj_drop=drop)
  247. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  248. self.norm2 = norm_layer(dim)
  249. mlp_hidden_dim = int(dim * mlp_ratio)
  250. self.mlp = Mlp(in_features=dim,
  251. hidden_features=mlp_hidden_dim,
  252. act_layer=act_layer,
  253. drop=drop)
  254. self.H = None
  255. self.W = None
  256. def forward(self, x, mask_matrix):
  257. """ Forward function.
  258. Args:
  259. x: Input feature, tensor size (B, H*W, C).
  260. H, W: Spatial resolution of the input feature.
  261. mask_matrix: Attention mask for cyclic shift.
  262. """
  263. B, L, C = x.shape
  264. H, W = self.H, self.W
  265. assert L == H * W, "input feature has wrong size"
  266. shortcut = x
  267. x = self.norm1(x)
  268. x = x.reshape([B, H, W, C])
  269. # pad feature maps to multiples of window size
  270. pad_l = pad_t = 0
  271. pad_r = (self.window_size - W % self.window_size) % self.window_size
  272. pad_b = (self.window_size - H % self.window_size) % self.window_size
  273. x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t])
  274. _, Hp, Wp, _ = x.shape
  275. # cyclic shift
  276. if self.shift_size > 0:
  277. shifted_x = paddle.roll(
  278. x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
  279. attn_mask = mask_matrix
  280. else:
  281. shifted_x = x
  282. attn_mask = None
  283. # partition windows
  284. x_windows = window_partition(
  285. shifted_x, self.window_size) # nW*B, window_size, window_size, C
  286. x_windows = x_windows.reshape(
  287. [-1, self.window_size * self.window_size,
  288. C]) # nW*B, window_size*window_size, C
  289. # W-MSA/SW-MSA
  290. attn_windows = self.attn(
  291. x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  292. # merge windows
  293. attn_windows = attn_windows.reshape(
  294. [-1, self.window_size, self.window_size, C])
  295. shifted_x = window_reverse(attn_windows, self.window_size, Hp,
  296. Wp) # B H' W' C
  297. # reverse cyclic shift
  298. if self.shift_size > 0:
  299. x = paddle.roll(
  300. shifted_x,
  301. shifts=(self.shift_size, self.shift_size),
  302. axis=(1, 2))
  303. else:
  304. x = shifted_x
  305. if pad_r > 0 or pad_b > 0:
  306. x = x[:, :H, :W, :]
  307. x = x.reshape([B, H * W, C])
  308. # FFN
  309. x = shortcut + self.drop_path(x)
  310. x = x + self.drop_path(self.mlp(self.norm2(x)))
  311. return x
  312. class PatchMerging(nn.Layer):
  313. r""" Patch Merging Layer.
  314. Args:
  315. dim (int): Number of input channels.
  316. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  317. """
  318. def __init__(self, dim, norm_layer=nn.LayerNorm):
  319. super().__init__()
  320. self.dim = dim
  321. self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
  322. self.norm = norm_layer(4 * dim)
  323. def forward(self, x, H, W):
  324. """ Forward function.
  325. Args:
  326. x: Input feature, tensor size (B, H*W, C).
  327. H, W: Spatial resolution of the input feature.
  328. """
  329. B, L, C = x.shape
  330. assert L == H * W, "input feature has wrong size"
  331. x = x.reshape([B, H, W, C])
  332. # padding
  333. pad_input = (H % 2 == 1) or (W % 2 == 1)
  334. if pad_input:
  335. x = F.pad(x, [0, 0, 0, W % 2, 0, H % 2])
  336. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  337. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  338. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  339. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  340. x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  341. x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
  342. x = self.norm(x)
  343. x = self.reduction(x)
  344. return x
  345. class BasicLayer(nn.Layer):
  346. """ A basic Swin Transformer layer for one stage.
  347. Args:
  348. dim (int): Number of input channels.
  349. input_resolution (tuple[int]): Input resolution.
  350. depth (int): Number of blocks.
  351. num_heads (int): Number of attention heads.
  352. window_size (int): Local window size.
  353. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  354. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  355. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  356. drop (float, optional): Dropout rate. Default: 0.0
  357. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  358. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  359. norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
  360. downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
  361. """
  362. def __init__(self,
  363. dim,
  364. depth,
  365. num_heads,
  366. window_size=7,
  367. mlp_ratio=4.,
  368. qkv_bias=True,
  369. qk_scale=None,
  370. drop=0.,
  371. attn_drop=0.,
  372. drop_path=0.,
  373. norm_layer=nn.LayerNorm,
  374. downsample=None):
  375. super().__init__()
  376. self.window_size = window_size
  377. self.shift_size = window_size // 2
  378. self.depth = depth
  379. # build blocks
  380. self.blocks = nn.LayerList([
  381. SwinTransformerBlock(
  382. dim=dim,
  383. num_heads=num_heads,
  384. window_size=window_size,
  385. shift_size=0 if (i % 2 == 0) else window_size // 2,
  386. mlp_ratio=mlp_ratio,
  387. qkv_bias=qkv_bias,
  388. qk_scale=qk_scale,
  389. drop=drop,
  390. attn_drop=attn_drop,
  391. drop_path=drop_path[i]
  392. if isinstance(drop_path, np.ndarray) else drop_path,
  393. norm_layer=norm_layer) for i in range(depth)
  394. ])
  395. # patch merging layer
  396. if downsample is not None:
  397. self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  398. else:
  399. self.downsample = None
  400. def forward(self, x, H, W):
  401. """ Forward function.
  402. Args:
  403. x: Input feature, tensor size (B, H*W, C).
  404. H, W: Spatial resolution of the input feature.
  405. """
  406. # calculate attention mask for SW-MSA
  407. Hp = int(np.ceil(H / self.window_size)) * self.window_size
  408. Wp = int(np.ceil(W / self.window_size)) * self.window_size
  409. img_mask = paddle.fluid.layers.zeros(
  410. [1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1
  411. h_slices = (slice(0, -self.window_size),
  412. slice(-self.window_size, -self.shift_size),
  413. slice(-self.shift_size, None))
  414. w_slices = (slice(0, -self.window_size),
  415. slice(-self.window_size, -self.shift_size),
  416. slice(-self.shift_size, None))
  417. cnt = 0
  418. for h in h_slices:
  419. for w in w_slices:
  420. img_mask[:, h, w, :] = cnt
  421. cnt += 1
  422. mask_windows = window_partition(
  423. img_mask, self.window_size) # nW, window_size, window_size, 1
  424. mask_windows = mask_windows.reshape(
  425. [-1, self.window_size * self.window_size])
  426. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  427. huns = -100.0 * paddle.ones_like(attn_mask)
  428. attn_mask = huns * (attn_mask != 0).astype("float32")
  429. for blk in self.blocks:
  430. blk.H, blk.W = H, W
  431. x = blk(x, attn_mask)
  432. if self.downsample is not None:
  433. x_down = self.downsample(x, H, W)
  434. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  435. return x, H, W, x_down, Wh, Ww
  436. else:
  437. return x, H, W, x, H, W
  438. class PatchEmbed(nn.Layer):
  439. """ Image to Patch Embedding
  440. Args:
  441. patch_size (int): Patch token size. Default: 4.
  442. in_chans (int): Number of input image channels. Default: 3.
  443. embed_dim (int): Number of linear projection output channels. Default: 96.
  444. norm_layer (nn.Layer, optional): Normalization layer. Default: None
  445. """
  446. def __init__(self, patch_size=4, in_chans=3, embed_dim=96,
  447. norm_layer=None):
  448. super().__init__()
  449. patch_size = to_2tuple(patch_size)
  450. self.patch_size = patch_size
  451. self.in_chans = in_chans
  452. self.embed_dim = embed_dim
  453. self.proj = nn.Conv2D(
  454. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  455. if norm_layer is not None:
  456. self.norm = norm_layer(embed_dim)
  457. else:
  458. self.norm = None
  459. def forward(self, x):
  460. B, C, H, W = x.shape
  461. # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
  462. if W % self.patch_size[1] != 0:
  463. x = F.pad(x, [0, self.patch_size[1] - W % self.patch_size[1]])
  464. if H % self.patch_size[0] != 0:
  465. x = F.pad(x,
  466. [0, 0, 0, self.patch_size[0] - H % self.patch_size[0]])
  467. x = self.proj(x)
  468. if self.norm is not None:
  469. _, _, Wh, Ww = x.shape
  470. x = x.flatten(2).transpose([0, 2, 1])
  471. x = self.norm(x)
  472. x = x.transpose([0, 2, 1]).reshape([-1, self.embed_dim, Wh, Ww])
  473. return x
  474. @register
  475. @serializable
  476. class SwinTransformer(nn.Layer):
  477. """ Swin Transformer
  478. A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  479. https://arxiv.org/pdf/2103.14030
  480. Args:
  481. img_size (int | tuple(int)): Input image size. Default 224
  482. patch_size (int | tuple(int)): Patch size. Default: 4
  483. in_chans (int): Number of input image channels. Default: 3
  484. num_classes (int): Number of classes for classification head. Default: 1000
  485. embed_dim (int): Patch embedding dimension. Default: 96
  486. depths (tuple(int)): Depth of each Swin Transformer layer.
  487. num_heads (tuple(int)): Number of attention heads in different layers.
  488. window_size (int): Window size. Default: 7
  489. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  490. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  491. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  492. drop_rate (float): Dropout rate. Default: 0
  493. attn_drop_rate (float): Attention dropout rate. Default: 0
  494. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  495. norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
  496. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  497. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  498. """
  499. def __init__(self,
  500. pretrain_img_size=224,
  501. patch_size=4,
  502. in_chans=3,
  503. embed_dim=96,
  504. depths=[2, 2, 6, 2],
  505. num_heads=[3, 6, 12, 24],
  506. window_size=7,
  507. mlp_ratio=4.,
  508. qkv_bias=True,
  509. qk_scale=None,
  510. drop_rate=0.,
  511. attn_drop_rate=0.,
  512. drop_path_rate=0.2,
  513. norm_layer=nn.LayerNorm,
  514. ape=False,
  515. patch_norm=True,
  516. out_indices=(0, 1, 2, 3),
  517. frozen_stages=-1,
  518. pretrained=None):
  519. super(SwinTransformer, self).__init__()
  520. self.pretrain_img_size = pretrain_img_size
  521. self.num_layers = len(depths)
  522. self.embed_dim = embed_dim
  523. self.ape = ape
  524. self.patch_norm = patch_norm
  525. self.out_indices = out_indices
  526. self.frozen_stages = frozen_stages
  527. # split image into non-overlapping patches
  528. self.patch_embed = PatchEmbed(
  529. patch_size=patch_size,
  530. in_chans=in_chans,
  531. embed_dim=embed_dim,
  532. norm_layer=norm_layer if self.patch_norm else None)
  533. # absolute position embedding
  534. if self.ape:
  535. pretrain_img_size = to_2tuple(pretrain_img_size)
  536. patch_size = to_2tuple(patch_size)
  537. patches_resolution = [
  538. pretrain_img_size[0] // patch_size[0],
  539. pretrain_img_size[1] // patch_size[1]
  540. ]
  541. self.absolute_pos_embed = add_parameter(
  542. self,
  543. paddle.zeros((1, embed_dim, patches_resolution[0],
  544. patches_resolution[1])))
  545. trunc_normal_(self.absolute_pos_embed)
  546. self.pos_drop = nn.Dropout(p=drop_rate)
  547. # stochastic depth
  548. dpr = np.linspace(0, drop_path_rate,
  549. sum(depths)) # stochastic depth decay rule
  550. # build layers
  551. self.layers = nn.LayerList()
  552. for i_layer in range(self.num_layers):
  553. layer = BasicLayer(
  554. dim=int(embed_dim * 2**i_layer),
  555. depth=depths[i_layer],
  556. num_heads=num_heads[i_layer],
  557. window_size=window_size,
  558. mlp_ratio=mlp_ratio,
  559. qkv_bias=qkv_bias,
  560. qk_scale=qk_scale,
  561. drop=drop_rate,
  562. attn_drop=attn_drop_rate,
  563. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  564. norm_layer=norm_layer,
  565. downsample=PatchMerging
  566. if (i_layer < self.num_layers - 1) else None)
  567. self.layers.append(layer)
  568. num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
  569. self.num_features = num_features
  570. # add a norm layer for each output
  571. for i_layer in out_indices:
  572. layer = norm_layer(num_features[i_layer])
  573. layer_name = f'norm{i_layer}'
  574. self.add_sublayer(layer_name, layer)
  575. self.apply(self._init_weights)
  576. self._freeze_stages()
  577. if pretrained:
  578. if 'http' in pretrained: #URL
  579. path = paddle.utils.download.get_weights_path_from_url(
  580. pretrained)
  581. else: #model in local path
  582. path = pretrained
  583. self.set_state_dict(paddle.load(path))
  584. def _freeze_stages(self):
  585. if self.frozen_stages >= 0:
  586. self.patch_embed.eval()
  587. for param in self.patch_embed.parameters():
  588. param.requires_grad = False
  589. if self.frozen_stages >= 1 and self.ape:
  590. self.absolute_pos_embed.requires_grad = False
  591. if self.frozen_stages >= 2:
  592. self.pos_drop.eval()
  593. for i in range(0, self.frozen_stages - 1):
  594. m = self.layers[i]
  595. m.eval()
  596. for param in m.parameters():
  597. param.requires_grad = False
  598. def _init_weights(self, m):
  599. if isinstance(m, nn.Linear):
  600. trunc_normal_(m.weight)
  601. if isinstance(m, nn.Linear) and m.bias is not None:
  602. zeros_(m.bias)
  603. elif isinstance(m, nn.LayerNorm):
  604. zeros_(m.bias)
  605. ones_(m.weight)
  606. def forward(self, x):
  607. """Forward function."""
  608. x = self.patch_embed(x['image'])
  609. _, _, Wh, Ww = x.shape
  610. if self.ape:
  611. # interpolate the position embedding to the corresponding size
  612. absolute_pos_embed = F.interpolate(
  613. self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
  614. x = (x + absolute_pos_embed).flatten(2).transpose([0, 2, 1])
  615. else:
  616. x = x.flatten(2).transpose([0, 2, 1])
  617. x = self.pos_drop(x)
  618. outs = []
  619. for i in range(self.num_layers):
  620. layer = self.layers[i]
  621. x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
  622. if i in self.out_indices:
  623. norm_layer = getattr(self, f'norm{i}')
  624. x_out = norm_layer(x_out)
  625. out = x_out.reshape(
  626. (-1, H, W, self.num_features[i])).transpose((0, 3, 1, 2))
  627. outs.append(out)
  628. return tuple(outs)
  629. @property
  630. def out_shape(self):
  631. out_strides = [4, 8, 16, 32]
  632. return [
  633. ShapeSpec(
  634. channels=self.num_features[i], stride=out_strides[i])
  635. for i in self.out_indices
  636. ]