resnest.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import math
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle import ParamAttr
  23. from paddle.nn.initializer import KaimingNormal
  24. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  25. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  26. from paddle.regularizer import L2Decay
  # Public model constructors exported by this module.
  __all__ = [
      "ResNeSt50_fast_1s1x64d",
      "ResNeSt50",
      "ResNeSt101",
  ]
  class ConvBNLayer(nn.Layer):
      """Bias-free Conv2D followed by BatchNorm with an optional activation.

      Args:
          num_channels: number of input channels.
          num_filters: number of output channels.
          filter_size: square kernel size.
          stride: convolution stride.
          dilation: convolution dilation; padding is scaled accordingly so
              that stride-1 convolutions with odd kernels preserve the
              spatial size.
          groups: number of convolution groups.
          act: activation handed to ``BatchNorm`` (e.g. ``"relu"``) or None.
          name: prefix for every parameter name ("_weight", "_scale",
              "_offset", "_mean", "_variance"). When None, no explicit names
              are given and naming is left to Paddle.
      """

      def __init__(self,
                   num_channels,
                   num_filters,
                   filter_size,
                   stride=1,
                   dilation=1,
                   groups=1,
                   act=None,
                   name=None):
          super(ConvBNLayer, self).__init__()

          def _param_name(suffix):
              # name=None used to raise TypeError on concatenation.
              return None if name is None else name + suffix

          bn_decay = 0.0
          self._conv = Conv2D(
              in_channels=num_channels,
              out_channels=num_filters,
              kernel_size=filter_size,
              stride=stride,
              # A dilated kernel spans dilation * (filter_size - 1) + 1 pixels,
              # so "same" padding must scale with dilation. For dilation == 1
              # this is exactly the previous (filter_size - 1) // 2.
              padding=(filter_size - 1) // 2 * dilation,
              dilation=dilation,
              groups=groups,
              weight_attr=ParamAttr(name=_param_name("_weight")),
              bias_attr=False)
          self._batch_norm = BatchNorm(
              num_filters,
              act=act,
              param_attr=ParamAttr(
                  name=_param_name("_scale"), regularizer=L2Decay(bn_decay)),
              bias_attr=ParamAttr(
                  _param_name("_offset"), regularizer=L2Decay(bn_decay)),
              moving_mean_name=_param_name("_mean"),
              moving_variance_name=_param_name("_variance"))

      def forward(self, x):
          x = self._conv(x)
          x = self._batch_norm(x)
          return x
  class rSoftmax(nn.Layer):
      """Radix-wise normalisation of split-attention logits.

      For ``radix > 1`` the input of shape ``(batch, r, h, w)`` is regrouped
      to ``(batch, cardinality, radix, r*h*w / (cardinality*radix))``, the
      cardinality and radix axes are swapped, softmax is taken over the radix
      axis and the result is flattened to ``(batch, r*h*w, 1, 1)``.
      Otherwise an elementwise sigmoid is applied.
      """

      def __init__(self, radix, cardinality):
          super(rSoftmax, self).__init__()
          self.radix = radix
          self.cardinality = cardinality

      def forward(self, x):
          n, c, height, width = x.shape
          if self.radix <= 1:
              return nn.functional.sigmoid(x)

          numel = c * height * width
          per_group = int(numel / self.cardinality / self.radix)
          grouped = paddle.reshape(
              x=x, shape=[n, self.cardinality, self.radix, per_group])
          # Bring the radix axis to position 1 so softmax normalises across splits.
          grouped = paddle.transpose(x=grouped, perm=[0, 2, 1, 3])
          weights = nn.functional.softmax(grouped, axis=1)
          return paddle.reshape(x=weights, shape=[n, numel, 1, 1])
  class SplatConv(nn.Layer):
      """Split-attention convolution.

      ``conv1`` yields ``radix`` groups of ``channels`` feature maps. Their sum
      is global-average-pooled, reduced by ``conv2``, expanded by ``conv3`` and
      normalised by ``rSoftmax`` into one attention map per split. The output
      is the attention-weighted sum of the splits (or, for ``radix <= 1``, the
      features scaled by the attention).

      Note: ``padding``, ``dilation``, ``bias`` and ``rectify_avg`` are accepted
      but not used by this implementation.
      """

      def __init__(self,
                   in_channels,
                   channels,
                   kernel_size,
                   stride=1,
                   padding=0,
                   dilation=1,
                   groups=1,
                   bias=True,
                   radix=2,
                   reduction_factor=4,
                   rectify_avg=False,
                   name=None):
          super(SplatConv, self).__init__()
          self.radix = radix
          split_channels = channels * radix

          self.conv1 = ConvBNLayer(
              num_channels=in_channels,
              num_filters=split_channels,
              filter_size=kernel_size,
              stride=stride,
              groups=groups * radix,
              act="relu",
              name=name + "_1_weights")
          self.avg_pool2d = AdaptiveAvgPool2D(1)

          # Bottleneck width of the attention MLP, never below 32.
          inter_channels = int(max(in_channels * radix // reduction_factor, 32))
          # Reduction of the pooled descriptor.
          self.conv2 = ConvBNLayer(
              num_channels=channels,
              num_filters=inter_channels,
              filter_size=1,
              stride=1,
              groups=groups,
              act="relu",
              name=name + "_2_weights")
          # Expansion back to one logit per split channel.
          self.conv3 = Conv2D(
              in_channels=inter_channels,
              out_channels=split_channels,
              kernel_size=1,
              stride=1,
              padding=0,
              groups=groups,
              weight_attr=ParamAttr(
                  name=name + "_weights", initializer=KaimingNormal()),
              bias_attr=False)
          self.rsoftmax = rSoftmax(radix=radix, cardinality=groups)

      def forward(self, x):
          feats = self.conv1(x)
          if self.radix > 1:
              splits = paddle.split(feats, num_or_sections=self.radix, axis=1)
              pooled = self.avg_pool2d(paddle.add_n(splits))
          else:
              pooled = self.avg_pool2d(feats)

          atten = self.rsoftmax(self.conv3(self.conv2(pooled)))

          if self.radix <= 1:
              return paddle.multiply(feats, atten)
          weights = paddle.split(atten, num_or_sections=self.radix, axis=1)
          return paddle.add_n(
              [paddle.multiply(part, w) for part, w in zip(splits, weights)])
  class BottleneckBlock(nn.Layer):
      """ResNeSt bottleneck: 1x1 reduce -> 3x3 (split-attention) -> 1x1 expand.

      The residual branch outputs ``planes * 4`` channels. When ``stride != 1``
      or the channel count changes, the shortcut is projected by an optional
      average pool (``avg_down``), a 1x1 conv and BatchNorm.

      When ``avd`` is set and the block downsamples (``stride > 1``) or is the
      first block of a stage (``is_first``), a 3x3 AvgPool2D with ``stride`` is
      placed before (``avd_first``) or after the 3x3 conv, and the 3x3 conv
      runs with stride 1. Otherwise the 3x3 conv itself carries ``stride``.
      ``radix >= 1`` selects SplatConv for the 3x3 conv; otherwise a plain
      ConvBNLayer is used.
      """

      def __init__(self,
                   inplanes,
                   planes,
                   stride=1,
                   radix=1,
                   cardinality=1,
                   bottleneck_width=64,
                   avd=False,
                   avd_first=False,
                   dilation=1,
                   is_first=False,
                   rectify_avg=False,
                   last_gamma=False,
                   avg_down=False,
                   name=None):
          super(BottleneckBlock, self).__init__()
          self.inplanes = inplanes
          self.planes = planes
          self.stride = stride
          self.radix = radix
          self.cardinality = cardinality
          self.avd = avd
          self.avd_first = avd_first
          self.dilation = dilation
          self.is_first = is_first
          self.rectify_avg = rectify_avg
          self.last_gamma = last_gamma
          self.avg_down = avg_down

          # AvgPool2D takes over spatial downsampling only in these cases.
          self.use_avd = bool(avd and (stride > 1 or is_first))
          # Bug fix: the 3x3 conv must downsample whenever AvgPool2D does not.
          # It used to be hard-wired to stride 1, so avd=False with stride > 1
          # left the residual branch at full resolution while the shortcut was
          # downsampled, and the residual add failed.
          conv2_stride = 1 if self.use_avd else stride

          group_width = int(planes * (bottleneck_width / 64.)) * cardinality
          self.conv1 = ConvBNLayer(
              num_channels=self.inplanes,
              num_filters=group_width,
              filter_size=1,
              stride=1,
              groups=1,
              act="relu",
              name=name + "_conv1")
          if self.use_avd and avd_first:
              self.avg_pool2d_1 = AvgPool2D(
                  kernel_size=3, stride=stride, padding=1)
          if radix >= 1:
              self.conv2 = SplatConv(
                  in_channels=group_width,
                  channels=group_width,
                  kernel_size=3,
                  stride=conv2_stride,
                  padding=dilation,
                  dilation=dilation,
                  groups=cardinality,
                  bias=False,
                  radix=radix,
                  rectify_avg=rectify_avg,
                  name=name + "_splat")
          else:
              self.conv2 = ConvBNLayer(
                  num_channels=group_width,
                  num_filters=group_width,
                  filter_size=3,
                  stride=conv2_stride,
                  dilation=dilation,
                  groups=cardinality,
                  act="relu",
                  name=name + "_conv2")
          if self.use_avd and not avd_first:
              self.avg_pool2d_2 = AvgPool2D(
                  kernel_size=3, stride=stride, padding=1)
          self.conv3 = ConvBNLayer(
              num_channels=group_width,
              num_filters=planes * 4,
              filter_size=1,
              stride=1,
              groups=1,
              act=None,
              name=name + "_conv3")

          # Projection shortcut whenever the residual branch changes shape.
          if stride != 1 or self.inplanes != self.planes * 4:
              if avg_down:
                  # Downsample by pooling, then project with a stride-1 conv.
                  if dilation == 1:
                      self.avg_pool2d_3 = AvgPool2D(
                          kernel_size=stride, stride=stride, padding=0)
                  else:
                      self.avg_pool2d_3 = AvgPool2D(
                          kernel_size=1, stride=1, padding=0, ceil_mode=True)
                  shortcut_stride = 1
                  shortcut_weight_name = name + "_weights"
              else:
                  shortcut_stride = stride
                  shortcut_weight_name = name + "_shortcut_weights"
              # The two variants use different weight names; kept unchanged,
              # presumably so existing checkpoints still match.
              self.conv4 = Conv2D(
                  in_channels=self.inplanes,
                  out_channels=planes * 4,
                  kernel_size=1,
                  stride=shortcut_stride,
                  padding=0,
                  groups=1,
                  weight_attr=ParamAttr(
                      name=shortcut_weight_name, initializer=KaimingNormal()),
                  bias_attr=False)
              bn_decay = 0.0
              self._batch_norm = BatchNorm(
                  planes * 4,
                  act=None,
                  param_attr=ParamAttr(
                      name=name + "_shortcut_scale",
                      regularizer=L2Decay(bn_decay)),
                  bias_attr=ParamAttr(
                      name + "_shortcut_offset", regularizer=L2Decay(bn_decay)),
                  moving_mean_name=name + "_shortcut_mean",
                  moving_variance_name=name + "_shortcut_variance")

      def forward(self, x):
          short = x
          x = self.conv1(x)
          if self.use_avd and self.avd_first:
              x = self.avg_pool2d_1(x)
          x = self.conv2(x)
          if self.use_avd and not self.avd_first:
              x = self.avg_pool2d_2(x)
          x = self.conv3(x)

          if self.stride != 1 or self.inplanes != self.planes * 4:
              if self.avg_down:
                  short = self.avg_pool2d_3(short)
              short = self.conv4(short)
              short = self._batch_norm(short)

          y = paddle.add(x=short, y=x)
          return F.relu(y)
  class ResNeStLayer(nn.Layer):
      """One stage made of ``blocks`` consecutive BottleneckBlock layers.

      Only the first block receives ``stride`` and ``is_first``. It runs with
      dilation 1 when the stage dilation is 1 or 2, and with dilation 2 when
      it is 4. Every later block uses stride 1, the stage ``dilation`` and
      ``planes * 4`` input channels. Sublayers are registered as
      ``<name>_bottleneck_<i>``.

      Raises:
          RuntimeError: if ``dilation`` is not 1, 2 or 4.
      """

      def __init__(self,
                   inplanes,
                   planes,
                   blocks,
                   radix,
                   cardinality,
                   bottleneck_width,
                   avg_down,
                   avd,
                   avd_first,
                   rectify_avg,
                   last_gamma,
                   stride=1,
                   dilation=1,
                   is_first=True,
                   name=None):
          super(ResNeStLayer, self).__init__()
          self.inplanes = inplanes
          self.planes = planes
          self.blocks = blocks
          self.radix = radix
          self.cardinality = cardinality
          self.bottleneck_width = bottleneck_width
          self.avg_down = avg_down
          self.avd = avd
          self.avd_first = avd_first
          self.rectify_avg = rectify_avg
          self.last_gamma = last_gamma
          self.is_first = is_first

          if dilation in (1, 2):
              first_dilation = 1
          elif dilation == 4:
              first_dilation = 2
          else:
              raise RuntimeError("=>unknown dilation size")

          # Settings shared by every block in the stage.
          common = dict(
              planes=planes,
              radix=radix,
              cardinality=cardinality,
              bottleneck_width=bottleneck_width,
              avg_down=self.avg_down,
              avd=avd,
              avd_first=avd_first,
              rectify_avg=rectify_avg,
              last_gamma=last_gamma)

          first_name = name + "_bottleneck_0"
          first_block = self.add_sublayer(
              first_name,
              BottleneckBlock(
                  inplanes=self.inplanes,
                  stride=stride,
                  dilation=first_dilation,
                  is_first=is_first,
                  name=first_name,
                  **common))

          # Later blocks consume the expanded output of their predecessor.
          self.inplanes = planes * 4
          self.bottleneck_block_list = [first_block]
          for idx in range(1, blocks):
              block_name = name + "_bottleneck_" + str(idx)
              block = self.add_sublayer(
                  block_name,
                  BottleneckBlock(
                      inplanes=self.inplanes,
                      dilation=dilation,
                      name=block_name,
                      **common))
              self.bottleneck_block_list.append(block)

      def forward(self, x):
          out = x
          for block in self.bottleneck_block_list:
              out = block(out)
          return out
  class ResNeSt(nn.Layer):
      """ResNeSt classifier: stem, max-pool, four ResNeSt stages, global pool, FC.

      Args:
          layers: number of bottleneck blocks in each of the four stages.
          radix: split-attention radix forwarded to every stage.
          groups: cardinality forwarded to every stage.
          bottleneck_width: base bottleneck width forwarded to every stage.
          dilated: if True, stages 3 and 4 use stride 1 with dilation 2 and 4.
          dilation: 4 behaves like ``dilated``; 2 keeps stage 3 strided and
              gives stage 4 stride 1 with dilation 2; other values leave all
              stages undilated.
          deep_stem: use three 3x3 ConvBNLayers (widths stem_width,
              stem_width, 2*stem_width) instead of one 7x7 ConvBNLayer.
          stem_width: width of the stem convolution(s).
          avg_down, rectify_avg, avd, avd_first, last_gamma: forwarded to
              every stage / bottleneck block.
          final_drop: dropout probability applied to the pooled features
              before the classifier; values <= 0 disable dropout.
          class_dim: number of output classes.
      """

      def __init__(self,
                   layers,
                   radix=1,
                   groups=1,
                   bottleneck_width=64,
                   dilated=False,
                   dilation=1,
                   deep_stem=False,
                   stem_width=64,
                   avg_down=False,
                   rectify_avg=False,
                   avd=False,
                   avd_first=False,
                   final_drop=0.0,
                   last_gamma=False,
                   class_dim=1000):
          super(ResNeSt, self).__init__()
          self.cardinality = groups
          self.bottleneck_width = bottleneck_width
          # ResNet-D params
          self.inplanes = stem_width * 2 if deep_stem else 64
          self.avg_down = avg_down
          self.last_gamma = last_gamma
          # ResNeSt params
          self.radix = radix
          self.avd = avd
          self.avd_first = avd_first
          self.deep_stem = deep_stem
          self.stem_width = stem_width
          self.layers = layers
          self.final_drop = final_drop
          self.dilated = dilated
          self.dilation = dilation
          self.rectify_avg = rectify_avg

          if self.deep_stem:
              self.stem = nn.Sequential(
                  ("conv1", ConvBNLayer(
                      num_channels=3,
                      num_filters=stem_width,
                      filter_size=3,
                      stride=2,
                      act="relu",
                      name="conv1")),
                  ("conv2", ConvBNLayer(
                      num_channels=stem_width,
                      num_filters=stem_width,
                      filter_size=3,
                      stride=1,
                      act="relu",
                      name="conv2")),
                  ("conv3", ConvBNLayer(
                      num_channels=stem_width,
                      num_filters=stem_width * 2,
                      filter_size=3,
                      stride=1,
                      act="relu",
                      name="conv3")))
          else:
              self.stem = ConvBNLayer(
                  num_channels=3,
                  num_filters=stem_width,
                  filter_size=7,
                  stride=2,
                  act="relu",
                  name="conv1")
          self.max_pool2d = MaxPool2D(kernel_size=3, stride=2, padding=1)

          self.layer1 = self._make_stage(
              inplanes=self.stem_width * 2
              if self.deep_stem else self.stem_width,
              planes=64,
              blocks=self.layers[0],
              stride=1,
              dilation=1,
              is_first=False,
              name="layer1")
          self.layer2 = self._make_stage(
              inplanes=256,
              planes=128,
              blocks=self.layers[1],
              stride=2,
              dilation=1,
              name="layer2")

          # (stride, dilation) of stages 3 and 4: dilated variants keep the
          # spatial resolution and enlarge the receptive field instead.
          if self.dilated or self.dilation == 4:
              stage3, stage4 = (1, 2), (1, 4)
          elif self.dilation == 2:
              stage3, stage4 = (2, 1), (1, 2)
          else:
              stage3, stage4 = (2, 1), (2, 1)
          self.layer3 = self._make_stage(
              inplanes=512,
              planes=256,
              blocks=self.layers[2],
              stride=stage3[0],
              dilation=stage3[1],
              name="layer3")
          self.layer4 = self._make_stage(
              inplanes=1024,
              planes=512,
              blocks=self.layers[3],
              stride=stage4[0],
              dilation=stage4[1],
              name="layer4")

          self.pool2d_avg = AdaptiveAvgPool2D(1)
          # final_drop used to be stored but silently ignored.
          self.drop = Dropout(p=final_drop) if final_drop > 0.0 else None
          self.out_channels = 2048
          stdv = 1.0 / math.sqrt(self.out_channels * 1.0)
          self.out = Linear(
              self.out_channels,
              class_dim,
              weight_attr=ParamAttr(
                  initializer=nn.initializer.Uniform(-stdv, stdv),
                  name="fc_weights"),
              bias_attr=ParamAttr(name="fc_offset"))

      def _make_stage(self,
                      inplanes,
                      planes,
                      blocks,
                      stride,
                      dilation,
                      name,
                      is_first=True):
          """Build one ResNeStLayer with the block settings stored on self."""
          return ResNeStLayer(
              inplanes=inplanes,
              planes=planes,
              blocks=blocks,
              radix=self.radix,
              cardinality=self.cardinality,
              bottleneck_width=self.bottleneck_width,
              avg_down=self.avg_down,
              avd=self.avd,
              avd_first=self.avd_first,
              rectify_avg=self.rectify_avg,
              last_gamma=self.last_gamma,
              stride=stride,
              dilation=dilation,
              is_first=is_first,
              name=name)

      def forward(self, x):
          x = self.stem(x)
          x = self.max_pool2d(x)
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
          x = self.pool2d_avg(x)
          x = paddle.reshape(x, shape=[-1, self.out_channels])
          if self.drop is not None:
              x = self.drop(x)
          x = self.out(x)
          return x
  def ResNeSt50_fast_1s1x64d(**args):
      """Build a ResNeSt-50 with radix 1, cardinality 1 and width 64.

      The model uses a deep stem of width 32, average-pool shortcuts, and avd
      pooling placed before the 3x3 conv. Extra keyword arguments (e.g.
      ``class_dim``) are forwarded to ``ResNeSt``.
      """
      stage_depths = [3, 4, 6, 3]
      return ResNeSt(
          layers=stage_depths,
          radix=1,
          groups=1,
          bottleneck_width=64,
          deep_stem=True,
          stem_width=32,
          avg_down=True,
          avd=True,
          avd_first=True,
          final_drop=0.0,
          **args)
  def ResNeSt50(**args):
      """Build a ResNeSt-50 with radix 2, cardinality 1 and width 64.

      The model uses a deep stem of width 32, average-pool shortcuts, and avd
      pooling placed after the 3x3 conv. Extra keyword arguments (e.g.
      ``class_dim``) are forwarded to ``ResNeSt``.
      """
      stage_depths = [3, 4, 6, 3]
      return ResNeSt(
          layers=stage_depths,
          radix=2,
          groups=1,
          bottleneck_width=64,
          deep_stem=True,
          stem_width=32,
          avg_down=True,
          avd=True,
          avd_first=False,
          final_drop=0.0,
          **args)
  def ResNeSt101(**args):
      """Build a ResNeSt-101 with radix 2, cardinality 1 and width 64.

      The model uses a deep stem of width 64, average-pool shortcuts, and avd
      pooling placed after the 3x3 conv. Extra keyword arguments (e.g.
      ``class_dim``) are forwarded to ``ResNeSt``.
      """
      stage_depths = [3, 4, 23, 3]
      return ResNeSt(
          layers=stage_depths,
          radix=2,
          groups=1,
          bottleneck_width=64,
          deep_stem=True,
          stem_width=64,
          avg_down=True,
          avd=True,
          avd_first=False,
          final_drop=0.0,
          **args)