# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math

__all__ = [
    "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C", "HRNet_W44_C",
    "HRNet_W48_C", "HRNet_W64_C"
]


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act="relu",
                 name=None):
        super(ConvBNLayer, self).__init__()
        self._conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        bn_name = name + '_bn'
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, input):
        y = self._conv(input)
        y = self._batch_norm(y)
        return y


class Layer1(nn.Layer):
    def __init__(self, num_channels, has_se=False, name=None):
        super(Layer1, self).__init__()
        self.bottleneck_block_list = []
        for i in range(4):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else 256,
                    num_filters=64,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, input):
        conv = input
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv
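

# TransitionLayer adapts the previous stage's outputs to the next stage's
# branch count and widths: an existing branch gets a 3x3 ConvBNLayer only if
# its channel count changes (otherwise it passes through unchanged), and each
# newly created branch is produced from the lowest-resolution input with a
# stride-2 3x3 ConvBNLayer.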
class TransitionLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, name=None):
        super(TransitionLayer, self).__init__()
        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        ConvBNLayer(
                            num_channels=in_channels[i],
                            num_filters=out_channels[i],
                            filter_size=3,
                            name=name + '_layer_' + str(i + 1)))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    ConvBNLayer(
                        num_channels=in_channels[-1],
                        num_filters=out_channels[i],
                        filter_size=3,
                        stride=2,
                        name=name + '_layer_' + str(i + 1)))
            self.conv_bn_func_list.append(residual)

    def forward(self, input):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(input[idx])
            else:
                if idx < len(input):
                    outs.append(conv_bn_func(input[idx]))
                else:
                    outs.append(conv_bn_func(input[-1]))
        return outs


class Branches(nn.Layer):
    def __init__(self,
                 block_num,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None):
        super(Branches, self).__init__()
        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(block_num):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            conv = input
            basic_block_list = self.basic_block_list[idx]
            for basic_block_func in basic_block_list:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None):
        super(BottleneckBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample
        self.conv1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            name=name + "_conv2")
        self.conv3 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            name=name + "_conv3")
        if self.downsample:
            self.conv_down = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                act=None,
                name=name + "_downsample")
        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        if self.downsample:
            residual = self.conv_down(input)
        if self.has_se:
            conv3 = self.se(conv3)
        y = paddle.add(x=residual, y=conv3)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample
        self.conv1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            act=None,
            name=name + "_conv2")
        if self.downsample:
            self.conv_down = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                act="relu",
                name=name + "_downsample")
        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        if self.downsample:
            residual = self.conv_down(input)
        if self.has_se:
            conv2 = self.se(conv2)
        y = paddle.add(x=residual, y=conv2)
        y = F.relu(y)
        return y
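

# SELayer implements squeeze-and-excitation: global average pooling reduces
# each channel to a scalar, two Linear layers (squeeze by reduction_ratio,
# then excite back) with ReLU and sigmoid produce per-channel weights, and
# the input feature map is rescaled channel-wise by those weights.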
class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
            bias_attr=ParamAttr(name=name + '_sqz_offset'))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_filters,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
            bias_attr=ParamAttr(name=name + '_exc_offset'))

    def forward(self, input):
        pool = self.pool2d_gap(input)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = input * excitation
        return out


class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None):
        super(Stage, self).__init__()
        self._num_modules = num_modules
        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1)))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1)))
            self.stage_func_list.append(stage_func)

    def forward(self, input):
        out = input
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None):
        super(HighResolutionModule, self).__init__()
        self.branches_func = Branches(
            block_num=4,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            name=name)
        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name)

    def forward(self, input):
        out = self.branches_func(input)
        out = self.fuse_func(out)
        return out
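

# FuseLayers exchanges information across resolutions at the end of each
# module. For output branch i, a lower-resolution input branch (j > i) is
# projected with a 1x1 conv and nearest-neighbor upsampled by 2**(j - i);
# a higher-resolution input branch (j < i) is downsampled through a chain of
# stride-2 3x3 convs. All contributions are summed and passed through ReLU.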
class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 name=None):
        super(FuseLayers, self).__init__()
        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                residual_func = None
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        ConvBNLayer(
                            num_channels=in_channels[j],
                            num_filters=out_channels[i],
                            filter_size=1,
                            stride=1,
                            act=None,
                            name=name + '_layer_' + str(i + 1) + '_' +
                            str(j + 1)))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvBNLayer(
                                    num_channels=pre_num_filters,
                                    num_filters=out_channels[i],
                                    filter_size=3,
                                    stride=2,
                                    act=None,
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvBNLayer(
                                    num_channels=pre_num_filters,
                                    num_filters=out_channels[j],
                                    filter_size=3,
                                    stride=2,
                                    act="relu",
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, input):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = input[i]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](input[j])
                    residual_func_idx += 1
                    y = F.upsample(y, scale_factor=2**(j - i), mode="nearest")
                    residual = paddle.add(x=residual, y=y)
                elif j < i:
                    y = input[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = paddle.add(x=residual, y=y)
            residual = F.relu(residual)
            outs.append(residual)
        return outs


class LastClsOut(nn.Layer):
    def __init__(self,
                 num_channel_list,
                 has_se,
                 num_filters_list=[32, 64, 128, 256],
                 name=None):
        super(LastClsOut, self).__init__()
        self.func_list = []
        for idx in range(len(num_channel_list)):
            func = self.add_sublayer(
                "conv_{}_conv_{}".format(name, idx + 1),
                BottleneckBlock(
                    num_channels=num_channel_list[idx],
                    num_filters=num_filters_list[idx],
                    has_se=has_se,
                    downsample=True,
                    name=name + 'conv_' + str(idx + 1)))
            self.func_list.append(func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            out = self.func_list[idx](input)
            outs.append(out)
        return outs
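

# HRNet assembles the full classifier: a two-conv stride-2 stem (1/4
# resolution), the bottleneck Layer1, then three transition/stage pairs that
# grow the parallel multi-resolution branches. The classification head
# bottlenecks each branch (LastClsOut), merges them top-down with stride-2
# convs, and finishes with a 1x1 conv to 2048 channels, global average
# pooling, and a Linear layer.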
class HRNet(nn.Layer):
    def __init__(self, width=18, has_se=False, class_dim=1000):
        super(HRNet, self).__init__()
        self.width = width
        self.has_se = has_se
        self.channels = {
            18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
            30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
            32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
            40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
            48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
            60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
            64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
        }
        self._class_dim = class_dim

        channels_2, channels_3, channels_4 = self.channels[width]
        num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3

        self.conv_layer1_1 = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=3,
            stride=2,
            act='relu',
            name="layer1_1")
        self.conv_layer1_2 = ConvBNLayer(
            num_channels=64,
            num_filters=64,
            filter_size=3,
            stride=2,
            act='relu',
            name="layer1_2")
        self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
        self.tr1 = TransitionLayer(
            in_channels=[256], out_channels=channels_2, name="tr1")
        self.st2 = Stage(
            num_channels=channels_2,
            num_modules=num_modules_2,
            num_filters=channels_2,
            has_se=self.has_se,
            name="st2")
        self.tr2 = TransitionLayer(
            in_channels=channels_2, out_channels=channels_3, name="tr2")
        self.st3 = Stage(
            num_channels=channels_3,
            num_modules=num_modules_3,
            num_filters=channels_3,
            has_se=self.has_se,
            name="st3")
        self.tr3 = TransitionLayer(
            in_channels=channels_3, out_channels=channels_4, name="tr3")
        self.st4 = Stage(
            num_channels=channels_4,
            num_modules=num_modules_4,
            num_filters=channels_4,
            has_se=self.has_se,
            name="st4")

        # classification head
        num_filters_list = [32, 64, 128, 256]
        self.last_cls = LastClsOut(
            num_channel_list=channels_4,
            has_se=self.has_se,
            num_filters_list=num_filters_list,
            name="cls_head")

        last_num_filters = [256, 512, 1024]
        self.cls_head_conv_list = []
        for idx in range(3):
            self.cls_head_conv_list.append(
                self.add_sublayer(
                    "cls_head_add{}".format(idx + 1),
                    ConvBNLayer(
                        num_channels=num_filters_list[idx] * 4,
                        num_filters=last_num_filters[idx],
                        filter_size=3,
                        stride=2,
                        name="cls_head_add" + str(idx + 1))))

        self.conv_last = ConvBNLayer(
            num_channels=1024,
            num_filters=2048,
            filter_size=1,
            stride=1,
            name="cls_head_last_conv")
        self.pool2d_avg = AdaptiveAvgPool2D(1)

        stdv = 1.0 / math.sqrt(2048 * 1.0)
        self.out = Linear(
            2048,
            class_dim,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name="fc_weights"),
            bias_attr=ParamAttr(name="fc_offset"))

    def forward(self, input):
        conv1 = self.conv_layer1_1(input)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)

        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)

        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)

        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)

        last_cls = self.last_cls(st4)

        y = last_cls[0]
        for idx in range(3):
            y = paddle.add(last_cls[idx + 1], self.cls_head_conv_list[idx](y))

        y = self.conv_last(y)
        y = self.pool2d_avg(y)
        y = paddle.reshape(y, shape=[-1, y.shape[1]])
        y = self.out(y)
        return y


def HRNet_W18_C(**args):
    model = HRNet(width=18, **args)
    return model


def HRNet_W30_C(**args):
    model = HRNet(width=30, **args)
    return model


def HRNet_W32_C(**args):
    model = HRNet(width=32, **args)
    return model


def HRNet_W40_C(**args):
    model = HRNet(width=40, **args)
    return model


def HRNet_W44_C(**args):
    model = HRNet(width=44, **args)
    return model


def HRNet_W48_C(**args):
    model = HRNet(width=48, **args)
    return model


def HRNet_W64_C(**args):
    model = HRNet(width=64, **args)
    return model
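

# Minimal usage sketch: build the HRNet-W18 classifier and run a forward pass
# on a random NCHW batch. The 224x224 input size is an assumption; any
# spatial size divisible by 32 keeps the branch resolutions consistent.
if __name__ == "__main__":
    model = HRNet_W18_C(class_dim=1000)
    model.eval()
    x = paddle.rand([1, 3, 224, 224])  # random input batch (assumed shape)
    logits = model(x)
    print(logits.shape)  # [1, 1000]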