# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from numbers import Integral

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, Uniform
from paddle.regularizer import L2Decay
from paddle.vision.ops import DeformConv2D

from paddlex.ppdet.core.workspace import register, serializable

from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec

__all__ = ['ResNet', 'Res5Head', 'Blocks', 'BasicBlock', 'BottleNeck']

ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}

class ConvNormLayer(nn.Layer):
    """Conv2D (or DeformConv2D when dcn_v2=True) followed by BN / SyncBN."""

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 act=None,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 lr=1.0,
                 dcn_v2=False):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn']
        self.norm_type = norm_type
        self.act = act
        self.dcn_v2 = dcn_v2

        if not self.dcn_v2:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(learning_rate=lr),
                bias_attr=False)
        else:
            # DCNv2 predicts 2 offset channels (x, y) and 1 mask channel per
            # kernel position, all from a single zero-initialized conv.
            self.offset_channel = 2 * filter_size**2
            self.mask_channel = filter_size**2
            self.conv_offset = nn.Conv2D(
                in_channels=ch_in,
                out_channels=3 * filter_size**2,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                weight_attr=ParamAttr(initializer=Constant(0.)),
                bias_attr=ParamAttr(initializer=Constant(0.)))
            self.conv = DeformConv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                dilation=1,
                groups=groups,
                weight_attr=ParamAttr(learning_rate=lr),
                bias_attr=False)

        norm_lr = 0. if freeze_norm else lr
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            trainable=False if freeze_norm else True)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            trainable=False if freeze_norm else True)
        global_stats = True if freeze_norm else False
        if norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self.norm = nn.BatchNorm(
                ch_out,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats)
        norm_params = self.norm.parameters()
        if freeze_norm:
            for param in norm_params:
                param.stop_gradient = True

    def forward(self, inputs):
        if not self.dcn_v2:
            out = self.conv(inputs)
        else:
            # Split the joint prediction into offsets and a sigmoid-gated mask.
            offset_mask = self.conv_offset(inputs)
            offset, mask = paddle.split(
                offset_mask,
                num_or_sections=[self.offset_channel, self.mask_channel],
                axis=1)
            mask = F.sigmoid(mask)
            out = self.conv(inputs, offset, mask=mask)

        if self.norm_type in ['bn', 'sync_bn']:
            out = self.norm(out)
        if self.act:
            out = getattr(F, self.act)(out)
        return out

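
def _demo_conv_norm_layer():
    # Hedged usage sketch (not part of the original file): shows the plain and
    # DCNv2 flavours of ConvNormLayer on a dummy tensor. The _demo_* name is
    # illustrative only and is never called by the module.
    x = paddle.rand([1, 64, 56, 56])
    plain = ConvNormLayer(ch_in=64, ch_out=128, filter_size=3, stride=2, act='relu')
    dcn = ConvNormLayer(ch_in=64, ch_out=128, filter_size=3, stride=2, dcn_v2=True)
    # Both variants halve the spatial size: expected shape [1, 128, 28, 28].
    return plain(x).shape, dcn(x).shape
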
class SELayer(nn.Layer):
    def __init__(self, ch, reduction_ratio=16):
        super(SELayer, self).__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)
        stdv = 1.0 / math.sqrt(ch)
        c_ = ch // reduction_ratio
        self.squeeze = nn.Linear(
            ch,
            c_,
            weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
            bias_attr=True)
        stdv = 1.0 / math.sqrt(c_)
        self.extract = nn.Linear(
            c_,
            ch,
            weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
            bias_attr=True)

    def forward(self, inputs):
        out = self.pool(inputs)
        out = paddle.squeeze(out, axis=[2, 3])
        out = self.squeeze(out)
        out = F.relu(out)
        out = self.extract(out)
        out = F.sigmoid(out)
        out = paddle.unsqueeze(out, axis=[2, 3])
        scale = out * inputs
        return scale

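
def _demo_se_layer():
    # Hedged sketch (not part of the original file): SELayer rescales each
    # channel by a gate learned from globally pooled features, so the output
    # shape matches the input. The _demo_* name is illustrative only.
    x = paddle.rand([1, 256, 14, 14])
    se = SELayer(256)
    return se(x).shape  # expected [1, 256, 14, 14]
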
class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(BasicBlock, self).__init__()
        assert groups == 1 and base_width == 64, \
            'BasicBlock only supports groups=1 and base_width=64'
        self.shortcut = shortcut
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        lr=lr))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=lr)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=3,
            stride=stride,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)
        self.branch2b = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=3,
            stride=1,
            act=None,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr,
            dcn_v2=dcn_v2)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(ch_out)

    def forward(self, inputs):
        out = self.branch2a(inputs)
        out = self.branch2b(out)
        if self.std_senet:
            out = self.se(out)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        out = paddle.add(x=out, y=short)
        out = F.relu(out)
        return out

class BottleNeck(nn.Layer):
    expansion = 4

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=4,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(BottleNeck, self).__init__()
        if variant == 'a':
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        # ResNeXt
        width = int(ch_out * (base_width / 64.)) * groups

        self.shortcut = shortcut
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out * self.expansion,
                        filter_size=1,
                        stride=1,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        lr=lr))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out * self.expansion,
                    filter_size=1,
                    stride=stride,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=lr)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width,
            filter_size=1,
            stride=stride1,
            groups=1,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)
        self.branch2b = ConvNormLayer(
            ch_in=width,
            ch_out=width,
            filter_size=3,
            stride=stride2,
            groups=groups,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr,
            dcn_v2=dcn_v2)
        self.branch2c = ConvNormLayer(
            ch_in=width,
            ch_out=ch_out * self.expansion,
            filter_size=1,
            stride=1,
            groups=1,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(ch_out * self.expansion)

    def forward(self, inputs):
        out = self.branch2a(inputs)
        out = self.branch2b(out)
        out = self.branch2c(out)
        if self.std_senet:
            out = self.se(out)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        out = paddle.add(x=out, y=short)
        out = F.relu(out)
        return out

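
def _demo_bottleneck():
    # Hedged sketch (not part of the original file): a BottleNeck block widens
    # its output by `expansion` (4), so ch_out=64 yields 256 output channels.
    # base_width=64 gives the plain-ResNet width; the _demo_* name is
    # illustrative only.
    x = paddle.rand([1, 64, 56, 56])
    block = BottleNeck(ch_in=64, ch_out=64, stride=1, shortcut=False, base_width=64)
    return block(x).shape  # expected [1, 256, 56, 56]
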
class Blocks(nn.Layer):
    def __init__(self,
                 block,
                 ch_in,
                 ch_out,
                 count,
                 name_adapter,
                 stage_num,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(Blocks, self).__init__()
        self.blocks = []
        for i in range(count):
            conv_name = name_adapter.fix_layer_warp_name(stage_num, count, i)
            layer = self.add_sublayer(
                conv_name,
                block(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=False if i == 0 else True,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    lr=lr,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=dcn_v2,
                    std_senet=std_senet))
            self.blocks.append(layer)
            if i == 0:
                ch_in = ch_out * block.expansion

    def forward(self, inputs):
        block_out = inputs
        for block in self.blocks:
            block_out = block(block_out)
        return block_out

@register
@serializable
class ResNet(nn.Layer):
    __shared__ = ['norm_type']

    def __init__(self,
                 depth=50,
                 ch_in=64,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 base_width=64,
                 norm_type='bn',
                 norm_decay=0,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 num_stages=4,
                 std_senet=False):
  399. """
  400. Residual Network, see https://arxiv.org/abs/1512.03385
  401. Args:
  402. depth (int): ResNet depth, should be 18, 34, 50, 101, 152.
  403. ch_in (int): output channel of first stage, default 64
  404. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  405. lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
  406. lower learning rate ratio is need for pretrained model
  407. got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
  408. groups (int): group convolution cardinality
  409. base_width (int): base width of each group convolution
  410. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  411. norm_decay (float): weight decay for normalization layer weights
  412. freeze_norm (bool): freeze normalization layers
  413. freeze_at (int): freeze the backbone at which stage
  414. return_idx (list): index of the stages whose feature maps are returned
  415. dcn_v2_stages (list): index of stages who select deformable conv v2
  416. num_stages (int): total num of stages
  417. std_senet (bool): whether use senet, default True
  418. """
        super(ResNet, self).__init__()
        self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
        assert num_stages >= 1 and num_stages <= 4
        self.depth = depth
        self.variant = variant
        self.groups = groups
        self.base_width = base_width
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must be smaller than num_stages, ' \
            'but received maximum return index {} and num_stages ' \
            '{}'.format(max(return_idx), num_stages)
        self.return_idx = return_idx
        self.num_stages = num_stages
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))
        if isinstance(dcn_v2_stages, Integral):
            dcn_v2_stages = [dcn_v2_stages]
        assert max(dcn_v2_stages) < num_stages
        self.dcn_v2_stages = dcn_v2_stages
        block_nums = ResNet_cfg[depth]
        na = NameAdapter(self)
        conv1_name = na.fix_c1_stage_name()
        if variant in ['c', 'd']:
            # Variants 'c'/'d' replace the 7x7 stem conv with three 3x3 convs.
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, conv1_name]]
        self.conv1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    groups=1,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=1.0))

        self.ch_in = ch_in
        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        self._out_channels = [block.expansion * v for v in ch_out_list]
        self._out_strides = [4, 8, 16, 32]

        self.res_layers = []
        for i in range(num_stages):
            lr_mult = lr_mult_list[i]
            stage_num = i + 2
            res_name = "res{}".format(stage_num)
            res_layer = self.add_sublayer(
                res_name,
                Blocks(
                    block,
                    self.ch_in,
                    ch_out_list[i],
                    count=block_nums[i],
                    name_adapter=na,
                    stage_num=stage_num,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    lr=lr_mult,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=(i in self.dcn_v2_stages),
                    std_senet=std_senet))
            self.res_layers.append(res_layer)
            self.ch_in = self._out_channels[i]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at + 1, num_stages)):
                self._freeze_parameters(self.res_layers[i])
    def _freeze_parameters(self, m):
        for p in m.parameters():
            p.stop_gradient = True

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]

    def forward(self, inputs):
        x = inputs['image']
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs

@register
class Res5Head(nn.Layer):
    def __init__(self, depth=50):
        super(Res5Head, self).__init__()
        feat_in, feat_out = [1024, 512]
        if depth < 50:
            feat_in = 256
        na = NameAdapter(self)
        block = BottleNeck if depth >= 50 else BasicBlock
        self.res5 = Blocks(
            block, feat_in, feat_out, count=3, name_adapter=na, stage_num=5)
        self.feat_out = feat_out if depth < 50 else feat_out * 4

    @property
    def out_shape(self):
        return [ShapeSpec(
            channels=self.feat_out,
            stride=16, )]

    def forward(self, roi_feat, stage=0):
        y = self.res5(roi_feat)
        return y
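
if __name__ == '__main__':
    # Hedged usage sketch (not part of the original file): build the backbone
    # and run a dummy forward pass. ResNet.forward expects a dict with an
    # 'image' tensor; out_shape describes the returned feature maps.
    backbone = ResNet(depth=50, return_idx=[0, 1, 2, 3])
    feats = backbone({'image': paddle.rand([1, 3, 224, 224])})
    for spec, feat in zip(backbone.out_shape, feats):
        print(feat.shape, 'channels:', spec.channels, 'stride:', spec.stride)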