res2net.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from numbers import Integral
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register, serializable
  19. from ..shape_spec import ShapeSpec
  20. from .resnet import ConvNormLayer
  21. __all__ = ['Res2Net', 'Res2NetC5']
  22. Res2Net_cfg = {
  23. 50: [3, 4, 6, 3],
  24. 101: [3, 4, 23, 3],
  25. 152: [3, 8, 36, 3],
  26. 200: [3, 12, 48, 3]
  27. }
class BottleNeck(nn.Layer):
    """Res2Net bottleneck block (see https://arxiv.org/abs/1904.01169).

    Replaces the single 3x3 conv of a ResNet bottleneck with `scales`
    channel splits processed by a hierarchy of 3x3 convs whose outputs
    are cumulatively added, giving multi-scale receptive fields.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels of the block.
        stride (int): spatial stride of the block (1 or 2).
        shortcut (bool): True -> identity shortcut; False -> a projection
            shortcut (`branch1`) is built to match channels/stride.
        width (int): channel width of each scale split.
        scales (int): number of channel splits; the bottleneck width is
            `width * scales`.
        variant (str): ResNet variant ('a', 'b', 'c' or 'd'); 'a' puts the
            stride on the 1x1 conv, 'd' adds an avg-pool to the projection
            shortcut when downsampling.
        groups (int): groups of the 3x3 convs (Res2NeXt when > 1).
        lr (float): learning-rate multiplier for this block's layers.
        norm_type (str): 'bn' or 'sync_bn'.
        norm_decay (float): weight decay for normalization layer weights.
        freeze_norm (bool): freeze normalization layer parameters.
        dcn_v2 (bool): use deformable conv v2 for the 3x3 convs.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 width,
                 scales=4,
                 variant='b',
                 groups=1,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(BottleNeck, self).__init__()
        self.shortcut = shortcut
        self.scales = scales
        self.stride = stride
        if not shortcut:
            # Projection shortcut. NOTE: sublayer names ('pool', 'conv')
            # determine state-dict keys; do not rename.
            if variant == 'd' and stride == 2:
                # Variant 'd': downsample with avg-pool, then a stride-1
                # 1x1 conv (avoids the information loss of a strided 1x1).
                self.branch1 = nn.Sequential()
                self.branch1.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.branch1.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        lr=lr))
            else:
                # Plain (possibly strided) 1x1 projection.
                self.branch1 = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=lr)
        # 1x1 reduction to `width * scales` channels. Variant 'a' carries
        # the stride here instead of in the 3x3 convs.
        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width * scales,
            filter_size=1,
            stride=stride if variant == 'a' else 1,
            groups=1,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)
        # One 3x3 conv per scale split except the last (which is passed
        # through, or avg-pooled when downsampling; see forward()).
        self.branch2b = nn.LayerList([
            ConvNormLayer(
                ch_in=width,
                ch_out=width,
                filter_size=3,
                stride=1 if variant == 'a' else stride,
                groups=groups,
                act='relu',
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                lr=lr,
                dcn_v2=dcn_v2) for _ in range(self.scales - 1)
        ])
        # 1x1 expansion back to ch_out (no activation; ReLU is applied
        # after the residual addition in forward()).
        self.branch2c = ConvNormLayer(
            ch_in=width * scales,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            groups=1,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

    def forward(self, inputs):
        """Apply the multi-scale bottleneck, then the residual addition."""
        out = self.branch2a(inputs)
        # Split bottleneck features into `scales` equal channel groups.
        feature_split = paddle.split(out, self.scales, 1)
        out_split = []
        for i in range(self.scales - 1):
            if i == 0 or self.stride == 2:
                # First split — or every split when downsampling, since the
                # previous split's output may have a different spatial size —
                # goes straight through its 3x3 conv.
                out_split.append(self.branch2b[i](feature_split[i]))
            else:
                # Hierarchical connection: add the previous split's output
                # before convolving.
                out_split.append(self.branch2b[i](paddle.add(feature_split[i],
                                                             out_split[-1])))
        if self.stride == 1:
            # Last split is passed through unchanged.
            out_split.append(feature_split[-1])
        else:
            # Downsampling: avg-pool the last split so all parts match.
            out_split.append(
                F.avg_pool2d(feature_split[-1], 3, self.stride, 1))
        out = self.branch2c(paddle.concat(out_split, 1))
        if self.shortcut:
            short = inputs
        else:
            short = self.branch1(inputs)
        out = paddle.add(out, short)
        out = F.relu(out)
        return out
class Blocks(nn.Layer):
    """One Res2Net stage: a sequence of `count` BottleNeck blocks.

    Only the first block may downsample and uses a projection shortcut;
    the remaining blocks use identity shortcuts.

    Args:
        ch_in (int): input channels of the stage (first block only).
        ch_out (int): output channels of every block in the stage.
        count (int): number of BottleNeck blocks.
        stage_num (int): stage index (2..5). Stage 2 never strides (the
            stem's max-pool already downsampled); the per-scale width is
            doubled each stage via `width * 2**(stage_num - 2)`.
        width (int): base per-scale width at stage 2.
        scales (int): number of Res2Net channel splits per block.
        variant (str): Res2Net variant ('a', 'b', 'c' or 'd').
        groups (int): conv groups (Res2NeXt when > 1).
        lr (float): learning-rate multiplier for this stage's layers.
        norm_type (str): 'bn' or 'sync_bn'.
        norm_decay (float): weight decay for normalization layer weights.
        freeze_norm (bool): freeze normalization layer parameters.
        dcn_v2 (bool): use deformable conv v2 in the blocks.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 count,
                 stage_num,
                 width,
                 scales=4,
                 variant='b',
                 groups=1,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(Blocks, self).__init__()
        self.blocks = nn.Sequential()
        for i in range(count):
            # Sublayers are named by index so pretrained state-dict keys
            # keep matching; do not rename.
            self.blocks.add_sublayer(
                str(i),
                BottleNeck(
                    # Only the first block changes channels / resolution.
                    ch_in=ch_in if i == 0 else ch_out,
                    ch_out=ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=False if i == 0 else True,
                    width=width * (2**(stage_num - 2)),
                    scales=scales,
                    variant=variant,
                    groups=groups,
                    lr=lr,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=dcn_v2))

    def forward(self, inputs):
        # Run the blocks sequentially.
        return self.blocks(inputs)
  169. @register
  170. @serializable
  171. class Res2Net(nn.Layer):
  172. """
  173. Res2Net, see https://arxiv.org/abs/1904.01169
  174. Args:
  175. depth (int): Res2Net depth, should be 50, 101, 152, 200.
  176. width (int): Res2Net width
  177. scales (int): Res2Net scale
  178. variant (str): Res2Net variant, supports 'a', 'b', 'c', 'd' currently
  179. lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
  180. lower learning rate ratio is need for pretrained model
  181. got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
  182. groups (int): The groups number of the Conv Layer.
  183. norm_type (str): normalization type, 'bn' or 'sync_bn'
  184. norm_decay (float): weight decay for normalization layer weights
  185. freeze_norm (bool): freeze normalization layers
  186. freeze_at (int): freeze the backbone at which stage
  187. return_idx (list): index of stages whose feature maps are returned,
  188. index 0 stands for res2
  189. dcn_v2_stages (list): index of stages who select deformable conv v2
  190. num_stages (int): number of stages created
  191. """
  192. __shared__ = ['norm_type']
  193. def __init__(self,
  194. depth=50,
  195. width=26,
  196. scales=4,
  197. variant='b',
  198. lr_mult_list=[1.0, 1.0, 1.0, 1.0],
  199. groups=1,
  200. norm_type='bn',
  201. norm_decay=0.,
  202. freeze_norm=True,
  203. freeze_at=0,
  204. return_idx=[0, 1, 2, 3],
  205. dcn_v2_stages=[-1],
  206. num_stages=4):
  207. super(Res2Net, self).__init__()
  208. self._model_type = 'Res2Net' if groups == 1 else 'Res2NeXt'
  209. assert depth in [50, 101, 152, 200], \
  210. "depth {} not in [50, 101, 152, 200]"
  211. assert variant in ['a', 'b', 'c', 'd'], "invalid Res2Net variant"
  212. assert num_stages >= 1 and num_stages <= 4
  213. self.depth = depth
  214. self.variant = variant
  215. self.norm_type = norm_type
  216. self.norm_decay = norm_decay
  217. self.freeze_norm = freeze_norm
  218. self.freeze_at = freeze_at
  219. if isinstance(return_idx, Integral):
  220. return_idx = [return_idx]
  221. assert max(return_idx) < num_stages, \
  222. 'the maximum return index must smaller than num_stages, ' \
  223. 'but received maximum return index is {} and num_stages ' \
  224. 'is {}'.format(max(return_idx), num_stages)
  225. self.return_idx = return_idx
  226. self.num_stages = num_stages
  227. assert len(lr_mult_list) == 4, \
  228. "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))
  229. if isinstance(dcn_v2_stages, Integral):
  230. dcn_v2_stages = [dcn_v2_stages]
  231. assert max(dcn_v2_stages) < num_stages
  232. self.dcn_v2_stages = dcn_v2_stages
  233. block_nums = Res2Net_cfg[depth]
  234. # C1 stage
  235. if self.variant in ['c', 'd']:
  236. conv_def = [
  237. [3, 32, 3, 2, "conv1_1"],
  238. [32, 32, 3, 1, "conv1_2"],
  239. [32, 64, 3, 1, "conv1_3"],
  240. ]
  241. else:
  242. conv_def = [[3, 64, 7, 2, "conv1"]]
  243. self.res1 = nn.Sequential()
  244. for (c_in, c_out, k, s, _name) in conv_def:
  245. self.res1.add_sublayer(
  246. _name,
  247. ConvNormLayer(
  248. ch_in=c_in,
  249. ch_out=c_out,
  250. filter_size=k,
  251. stride=s,
  252. groups=1,
  253. act='relu',
  254. norm_type=norm_type,
  255. norm_decay=norm_decay,
  256. freeze_norm=freeze_norm,
  257. lr=1.0))
  258. self._in_channels = [64, 256, 512, 1024]
  259. self._out_channels = [256, 512, 1024, 2048]
  260. self._out_strides = [4, 8, 16, 32]
  261. # C2-C5 stages
  262. self.res_layers = []
  263. for i in range(num_stages):
  264. lr_mult = lr_mult_list[i]
  265. stage_num = i + 2
  266. self.res_layers.append(
  267. self.add_sublayer(
  268. "res{}".format(stage_num),
  269. Blocks(
  270. self._in_channels[i],
  271. self._out_channels[i],
  272. count=block_nums[i],
  273. stage_num=stage_num,
  274. width=width,
  275. scales=scales,
  276. groups=groups,
  277. lr=lr_mult,
  278. norm_type=norm_type,
  279. norm_decay=norm_decay,
  280. freeze_norm=freeze_norm,
  281. dcn_v2=(i in self.dcn_v2_stages))))
  282. @property
  283. def out_shape(self):
  284. return [
  285. ShapeSpec(
  286. channels=self._out_channels[i], stride=self._out_strides[i])
  287. for i in self.return_idx
  288. ]
  289. def forward(self, inputs):
  290. x = inputs['image']
  291. res1 = self.res1(x)
  292. x = F.max_pool2d(res1, kernel_size=3, stride=2, padding=1)
  293. outs = []
  294. for idx, stage in enumerate(self.res_layers):
  295. x = stage(x)
  296. if idx == self.freeze_at:
  297. x.stop_gradient = True
  298. if idx in self.return_idx:
  299. outs.append(x)
  300. return outs
  301. @register
  302. class Res2NetC5(nn.Layer):
  303. def __init__(self, depth=50, width=26, scales=4, variant='b'):
  304. super(Res2NetC5, self).__init__()
  305. feat_in, feat_out = [1024, 2048]
  306. self.res5 = Blocks(
  307. feat_in,
  308. feat_out,
  309. count=3,
  310. stage_num=5,
  311. width=width,
  312. scales=scales,
  313. variant=variant)
  314. self.feat_out = feat_out
  315. @property
  316. def out_shape(self):
  317. return [ShapeSpec(
  318. channels=self.feat_out,
  319. stride=32, )]
  320. def forward(self, roi_feat, stage=0):
  321. y = self.res5(roi_feat)
  322. return y