hrnet.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import ParamAttr
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn import Conv2D, BatchNorm, Linear
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. __all__ = [
  27. "HRNet_W18_C",
  28. "HRNet_W30_C",
  29. "HRNet_W32_C",
  30. "HRNet_W40_C",
  31. "HRNet_W44_C",
  32. "HRNet_W48_C",
  33. "HRNet_W60_C",
  34. "HRNet_W64_C",
  35. "SE_HRNet_W18_C",
  36. "SE_HRNet_W30_C",
  37. "SE_HRNet_W32_C",
  38. "SE_HRNet_W40_C",
  39. "SE_HRNet_W44_C",
  40. "SE_HRNet_W48_C",
  41. "SE_HRNet_W60_C",
  42. "SE_HRNet_W64_C",
  43. ]
  44. class ConvBNLayer(nn.Layer):
  45. def __init__(self,
  46. num_channels,
  47. num_filters,
  48. filter_size,
  49. stride=1,
  50. groups=1,
  51. act="relu",
  52. name=None):
  53. super(ConvBNLayer, self).__init__()
  54. self._conv = Conv2D(
  55. in_channels=num_channels,
  56. out_channels=num_filters,
  57. kernel_size=filter_size,
  58. stride=stride,
  59. padding=(filter_size - 1) // 2,
  60. groups=groups,
  61. weight_attr=ParamAttr(name=name + "_weights"),
  62. bias_attr=False)
  63. bn_name = name + '_bn'
  64. self._batch_norm = BatchNorm(
  65. num_filters,
  66. act=act,
  67. param_attr=ParamAttr(name=bn_name + '_scale'),
  68. bias_attr=ParamAttr(bn_name + '_offset'),
  69. moving_mean_name=bn_name + '_mean',
  70. moving_variance_name=bn_name + '_variance')
  71. def forward(self, input):
  72. y = self._conv(input)
  73. y = self._batch_norm(y)
  74. return y
  75. class Layer1(nn.Layer):
  76. def __init__(self, num_channels, has_se=False, name=None):
  77. super(Layer1, self).__init__()
  78. self.bottleneck_block_list = []
  79. for i in range(4):
  80. bottleneck_block = self.add_sublayer(
  81. "bb_{}_{}".format(name, i + 1),
  82. BottleneckBlock(
  83. num_channels=num_channels if i == 0 else 256,
  84. num_filters=64,
  85. has_se=has_se,
  86. stride=1,
  87. downsample=True if i == 0 else False,
  88. name=name + '_' + str(i + 1)))
  89. self.bottleneck_block_list.append(bottleneck_block)
  90. def forward(self, input):
  91. conv = input
  92. for block_func in self.bottleneck_block_list:
  93. conv = block_func(conv)
  94. return conv
  95. class TransitionLayer(nn.Layer):
  96. def __init__(self, in_channels, out_channels, name=None):
  97. super(TransitionLayer, self).__init__()
  98. num_in = len(in_channels)
  99. num_out = len(out_channels)
  100. out = []
  101. self.conv_bn_func_list = []
  102. for i in range(num_out):
  103. residual = None
  104. if i < num_in:
  105. if in_channels[i] != out_channels[i]:
  106. residual = self.add_sublayer(
  107. "transition_{}_layer_{}".format(name, i + 1),
  108. ConvBNLayer(
  109. num_channels=in_channels[i],
  110. num_filters=out_channels[i],
  111. filter_size=3,
  112. name=name + '_layer_' + str(i + 1)))
  113. else:
  114. residual = self.add_sublayer(
  115. "transition_{}_layer_{}".format(name, i + 1),
  116. ConvBNLayer(
  117. num_channels=in_channels[-1],
  118. num_filters=out_channels[i],
  119. filter_size=3,
  120. stride=2,
  121. name=name + '_layer_' + str(i + 1)))
  122. self.conv_bn_func_list.append(residual)
  123. def forward(self, input):
  124. outs = []
  125. for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
  126. if conv_bn_func is None:
  127. outs.append(input[idx])
  128. else:
  129. if idx < len(input):
  130. outs.append(conv_bn_func(input[idx]))
  131. else:
  132. outs.append(conv_bn_func(input[-1]))
  133. return outs
  134. class Branches(nn.Layer):
  135. def __init__(self,
  136. block_num,
  137. in_channels,
  138. out_channels,
  139. has_se=False,
  140. name=None):
  141. super(Branches, self).__init__()
  142. self.basic_block_list = []
  143. for i in range(len(out_channels)):
  144. self.basic_block_list.append([])
  145. for j in range(block_num):
  146. in_ch = in_channels[i] if j == 0 else out_channels[i]
  147. basic_block_func = self.add_sublayer(
  148. "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
  149. BasicBlock(
  150. num_channels=in_ch,
  151. num_filters=out_channels[i],
  152. has_se=has_se,
  153. name=name + '_branch_layer_' + str(i + 1) + '_' +
  154. str(j + 1)))
  155. self.basic_block_list[i].append(basic_block_func)
  156. def forward(self, inputs):
  157. outs = []
  158. for idx, input in enumerate(inputs):
  159. conv = input
  160. basic_block_list = self.basic_block_list[idx]
  161. for basic_block_func in basic_block_list:
  162. conv = basic_block_func(conv)
  163. outs.append(conv)
  164. return outs
  165. class BottleneckBlock(nn.Layer):
  166. def __init__(self,
  167. num_channels,
  168. num_filters,
  169. has_se,
  170. stride=1,
  171. downsample=False,
  172. name=None):
  173. super(BottleneckBlock, self).__init__()
  174. self.has_se = has_se
  175. self.downsample = downsample
  176. self.conv1 = ConvBNLayer(
  177. num_channels=num_channels,
  178. num_filters=num_filters,
  179. filter_size=1,
  180. act="relu",
  181. name=name + "_conv1", )
  182. self.conv2 = ConvBNLayer(
  183. num_channels=num_filters,
  184. num_filters=num_filters,
  185. filter_size=3,
  186. stride=stride,
  187. act="relu",
  188. name=name + "_conv2")
  189. self.conv3 = ConvBNLayer(
  190. num_channels=num_filters,
  191. num_filters=num_filters * 4,
  192. filter_size=1,
  193. act=None,
  194. name=name + "_conv3")
  195. if self.downsample:
  196. self.conv_down = ConvBNLayer(
  197. num_channels=num_channels,
  198. num_filters=num_filters * 4,
  199. filter_size=1,
  200. act=None,
  201. name=name + "_downsample")
  202. if self.has_se:
  203. self.se = SELayer(
  204. num_channels=num_filters * 4,
  205. num_filters=num_filters * 4,
  206. reduction_ratio=16,
  207. name='fc' + name)
  208. def forward(self, input):
  209. residual = input
  210. conv1 = self.conv1(input)
  211. conv2 = self.conv2(conv1)
  212. conv3 = self.conv3(conv2)
  213. if self.downsample:
  214. residual = self.conv_down(input)
  215. if self.has_se:
  216. conv3 = self.se(conv3)
  217. y = paddle.add(x=residual, y=conv3)
  218. y = F.relu(y)
  219. return y
  220. class BasicBlock(nn.Layer):
  221. def __init__(self,
  222. num_channels,
  223. num_filters,
  224. stride=1,
  225. has_se=False,
  226. downsample=False,
  227. name=None):
  228. super(BasicBlock, self).__init__()
  229. self.has_se = has_se
  230. self.downsample = downsample
  231. self.conv1 = ConvBNLayer(
  232. num_channels=num_channels,
  233. num_filters=num_filters,
  234. filter_size=3,
  235. stride=stride,
  236. act="relu",
  237. name=name + "_conv1")
  238. self.conv2 = ConvBNLayer(
  239. num_channels=num_filters,
  240. num_filters=num_filters,
  241. filter_size=3,
  242. stride=1,
  243. act=None,
  244. name=name + "_conv2")
  245. if self.downsample:
  246. self.conv_down = ConvBNLayer(
  247. num_channels=num_channels,
  248. num_filters=num_filters * 4,
  249. filter_size=1,
  250. act="relu",
  251. name=name + "_downsample")
  252. if self.has_se:
  253. self.se = SELayer(
  254. num_channels=num_filters,
  255. num_filters=num_filters,
  256. reduction_ratio=16,
  257. name='fc' + name)
  258. def forward(self, input):
  259. residual = input
  260. conv1 = self.conv1(input)
  261. conv2 = self.conv2(conv1)
  262. if self.downsample:
  263. residual = self.conv_down(input)
  264. if self.has_se:
  265. conv2 = self.se(conv2)
  266. y = paddle.add(x=residual, y=conv2)
  267. y = F.relu(y)
  268. return y
  269. class SELayer(nn.Layer):
  270. def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
  271. super(SELayer, self).__init__()
  272. self.pool2d_gap = AdaptiveAvgPool2D(1)
  273. self._num_channels = num_channels
  274. med_ch = int(num_channels / reduction_ratio)
  275. stdv = 1.0 / math.sqrt(num_channels * 1.0)
  276. self.squeeze = Linear(
  277. num_channels,
  278. med_ch,
  279. weight_attr=ParamAttr(
  280. initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
  281. bias_attr=ParamAttr(name=name + '_sqz_offset'))
  282. stdv = 1.0 / math.sqrt(med_ch * 1.0)
  283. self.excitation = Linear(
  284. med_ch,
  285. num_filters,
  286. weight_attr=ParamAttr(
  287. initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
  288. bias_attr=ParamAttr(name=name + '_exc_offset'))
  289. def forward(self, input):
  290. pool = self.pool2d_gap(input)
  291. pool = paddle.squeeze(pool, axis=[2, 3])
  292. squeeze = self.squeeze(pool)
  293. squeeze = F.relu(squeeze)
  294. excitation = self.excitation(squeeze)
  295. excitation = F.sigmoid(excitation)
  296. excitation = paddle.unsqueeze(excitation, axis=[2, 3])
  297. out = input * excitation
  298. return out
  299. class Stage(nn.Layer):
  300. def __init__(self,
  301. num_channels,
  302. num_modules,
  303. num_filters,
  304. has_se=False,
  305. multi_scale_output=True,
  306. name=None):
  307. super(Stage, self).__init__()
  308. self._num_modules = num_modules
  309. self.stage_func_list = []
  310. for i in range(num_modules):
  311. if i == num_modules - 1 and not multi_scale_output:
  312. stage_func = self.add_sublayer(
  313. "stage_{}_{}".format(name, i + 1),
  314. HighResolutionModule(
  315. num_channels=num_channels,
  316. num_filters=num_filters,
  317. has_se=has_se,
  318. multi_scale_output=False,
  319. name=name + '_' + str(i + 1)))
  320. else:
  321. stage_func = self.add_sublayer(
  322. "stage_{}_{}".format(name, i + 1),
  323. HighResolutionModule(
  324. num_channels=num_channels,
  325. num_filters=num_filters,
  326. has_se=has_se,
  327. name=name + '_' + str(i + 1)))
  328. self.stage_func_list.append(stage_func)
  329. def forward(self, input):
  330. out = input
  331. for idx in range(self._num_modules):
  332. out = self.stage_func_list[idx](out)
  333. return out
  334. class HighResolutionModule(nn.Layer):
  335. def __init__(self,
  336. num_channels,
  337. num_filters,
  338. has_se=False,
  339. multi_scale_output=True,
  340. name=None):
  341. super(HighResolutionModule, self).__init__()
  342. self.branches_func = Branches(
  343. block_num=4,
  344. in_channels=num_channels,
  345. out_channels=num_filters,
  346. has_se=has_se,
  347. name=name)
  348. self.fuse_func = FuseLayers(
  349. in_channels=num_filters,
  350. out_channels=num_filters,
  351. multi_scale_output=multi_scale_output,
  352. name=name)
  353. def forward(self, input):
  354. out = self.branches_func(input)
  355. out = self.fuse_func(out)
  356. return out
  357. class FuseLayers(nn.Layer):
  358. def __init__(self,
  359. in_channels,
  360. out_channels,
  361. multi_scale_output=True,
  362. name=None):
  363. super(FuseLayers, self).__init__()
  364. self._actual_ch = len(in_channels) if multi_scale_output else 1
  365. self._in_channels = in_channels
  366. self.residual_func_list = []
  367. for i in range(self._actual_ch):
  368. for j in range(len(in_channels)):
  369. residual_func = None
  370. if j > i:
  371. residual_func = self.add_sublayer(
  372. "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
  373. ConvBNLayer(
  374. num_channels=in_channels[j],
  375. num_filters=out_channels[i],
  376. filter_size=1,
  377. stride=1,
  378. act=None,
  379. name=name + '_layer_' + str(i + 1) + '_' +
  380. str(j + 1)))
  381. self.residual_func_list.append(residual_func)
  382. elif j < i:
  383. pre_num_filters = in_channels[j]
  384. for k in range(i - j):
  385. if k == i - j - 1:
  386. residual_func = self.add_sublayer(
  387. "residual_{}_layer_{}_{}_{}".format(
  388. name, i + 1, j + 1, k + 1),
  389. ConvBNLayer(
  390. num_channels=pre_num_filters,
  391. num_filters=out_channels[i],
  392. filter_size=3,
  393. stride=2,
  394. act=None,
  395. name=name + '_layer_' + str(i + 1) + '_' +
  396. str(j + 1) + '_' + str(k + 1)))
  397. pre_num_filters = out_channels[i]
  398. else:
  399. residual_func = self.add_sublayer(
  400. "residual_{}_layer_{}_{}_{}".format(
  401. name, i + 1, j + 1, k + 1),
  402. ConvBNLayer(
  403. num_channels=pre_num_filters,
  404. num_filters=out_channels[j],
  405. filter_size=3,
  406. stride=2,
  407. act="relu",
  408. name=name + '_layer_' + str(i + 1) + '_' +
  409. str(j + 1) + '_' + str(k + 1)))
  410. pre_num_filters = out_channels[j]
  411. self.residual_func_list.append(residual_func)
  412. def forward(self, input):
  413. outs = []
  414. residual_func_idx = 0
  415. for i in range(self._actual_ch):
  416. residual = input[i]
  417. for j in range(len(self._in_channels)):
  418. if j > i:
  419. y = self.residual_func_list[residual_func_idx](input[j])
  420. residual_func_idx += 1
  421. y = F.upsample(y, scale_factor=2**(j - i), mode="nearest")
  422. residual = paddle.add(x=residual, y=y)
  423. elif j < i:
  424. y = input[j]
  425. for k in range(i - j):
  426. y = self.residual_func_list[residual_func_idx](y)
  427. residual_func_idx += 1
  428. residual = paddle.add(x=residual, y=y)
  429. residual = F.relu(residual)
  430. outs.append(residual)
  431. return outs
  432. class LastClsOut(nn.Layer):
  433. def __init__(self,
  434. num_channel_list,
  435. has_se,
  436. num_filters_list=[32, 64, 128, 256],
  437. name=None):
  438. super(LastClsOut, self).__init__()
  439. self.func_list = []
  440. for idx in range(len(num_channel_list)):
  441. func = self.add_sublayer(
  442. "conv_{}_conv_{}".format(name, idx + 1),
  443. BottleneckBlock(
  444. num_channels=num_channel_list[idx],
  445. num_filters=num_filters_list[idx],
  446. has_se=has_se,
  447. downsample=True,
  448. name=name + 'conv_' + str(idx + 1)))
  449. self.func_list.append(func)
  450. def forward(self, inputs):
  451. outs = []
  452. for idx, input in enumerate(inputs):
  453. out = self.func_list[idx](input)
  454. outs.append(out)
  455. return outs
  456. class HRNet(nn.Layer):
  457. def __init__(self, width=18, has_se=False, class_dim=1000):
  458. super(HRNet, self).__init__()
  459. self.width = width
  460. self.has_se = has_se
  461. self.channels = {
  462. 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
  463. 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
  464. 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
  465. 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
  466. 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
  467. 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
  468. 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
  469. 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
  470. }
  471. self._class_dim = class_dim
  472. channels_2, channels_3, channels_4 = self.channels[width]
  473. num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
  474. self.conv_layer1_1 = ConvBNLayer(
  475. num_channels=3,
  476. num_filters=64,
  477. filter_size=3,
  478. stride=2,
  479. act='relu',
  480. name="layer1_1")
  481. self.conv_layer1_2 = ConvBNLayer(
  482. num_channels=64,
  483. num_filters=64,
  484. filter_size=3,
  485. stride=2,
  486. act='relu',
  487. name="layer1_2")
  488. self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
  489. self.tr1 = TransitionLayer(
  490. in_channels=[256], out_channels=channels_2, name="tr1")
  491. self.st2 = Stage(
  492. num_channels=channels_2,
  493. num_modules=num_modules_2,
  494. num_filters=channels_2,
  495. has_se=self.has_se,
  496. name="st2")
  497. self.tr2 = TransitionLayer(
  498. in_channels=channels_2, out_channels=channels_3, name="tr2")
  499. self.st3 = Stage(
  500. num_channels=channels_3,
  501. num_modules=num_modules_3,
  502. num_filters=channels_3,
  503. has_se=self.has_se,
  504. name="st3")
  505. self.tr3 = TransitionLayer(
  506. in_channels=channels_3, out_channels=channels_4, name="tr3")
  507. self.st4 = Stage(
  508. num_channels=channels_4,
  509. num_modules=num_modules_4,
  510. num_filters=channels_4,
  511. has_se=self.has_se,
  512. name="st4")
  513. # classification
  514. num_filters_list = [32, 64, 128, 256]
  515. self.last_cls = LastClsOut(
  516. num_channel_list=channels_4,
  517. has_se=self.has_se,
  518. num_filters_list=num_filters_list,
  519. name="cls_head", )
  520. last_num_filters = [256, 512, 1024]
  521. self.cls_head_conv_list = []
  522. for idx in range(3):
  523. self.cls_head_conv_list.append(
  524. self.add_sublayer(
  525. "cls_head_add{}".format(idx + 1),
  526. ConvBNLayer(
  527. num_channels=num_filters_list[idx] * 4,
  528. num_filters=last_num_filters[idx],
  529. filter_size=3,
  530. stride=2,
  531. name="cls_head_add" + str(idx + 1))))
  532. self.conv_last = ConvBNLayer(
  533. num_channels=1024,
  534. num_filters=2048,
  535. filter_size=1,
  536. stride=1,
  537. name="cls_head_last_conv")
  538. self.pool2d_avg = AdaptiveAvgPool2D(1)
  539. stdv = 1.0 / math.sqrt(2048 * 1.0)
  540. self.out = Linear(
  541. 2048,
  542. class_dim,
  543. weight_attr=ParamAttr(
  544. initializer=Uniform(-stdv, stdv), name="fc_weights"),
  545. bias_attr=ParamAttr(name="fc_offset"))
  546. def forward(self, input):
  547. conv1 = self.conv_layer1_1(input)
  548. conv2 = self.conv_layer1_2(conv1)
  549. la1 = self.la1(conv2)
  550. tr1 = self.tr1([la1])
  551. st2 = self.st2(tr1)
  552. tr2 = self.tr2(st2)
  553. st3 = self.st3(tr2)
  554. tr3 = self.tr3(st3)
  555. st4 = self.st4(tr3)
  556. last_cls = self.last_cls(st4)
  557. y = last_cls[0]
  558. for idx in range(3):
  559. y = paddle.add(last_cls[idx + 1], self.cls_head_conv_list[idx](y))
  560. y = self.conv_last(y)
  561. y = self.pool2d_avg(y)
  562. y = paddle.reshape(y, shape=[-1, y.shape[1]])
  563. y = self.out(y)
  564. return y
  565. def HRNet_W18_C(**args):
  566. model = HRNet(width=18, **args)
  567. return model
  568. def HRNet_W30_C(**args):
  569. model = HRNet(width=30, **args)
  570. return model
  571. def HRNet_W32_C(**args):
  572. model = HRNet(width=32, **args)
  573. return model
  574. def HRNet_W40_C(**args):
  575. model = HRNet(width=40, **args)
  576. return model
  577. def HRNet_W44_C(**args):
  578. model = HRNet(width=44, **args)
  579. return model
  580. def HRNet_W48_C(**args):
  581. model = HRNet(width=48, **args)
  582. return model
  583. def HRNet_W60_C(**args):
  584. model = HRNet(width=60, **args)
  585. return model
  586. def HRNet_W64_C(**args):
  587. model = HRNet(width=64, **args)
  588. return model
  589. def SE_HRNet_W18_C(**args):
  590. model = HRNet(width=18, has_se=True, **args)
  591. return model
  592. def SE_HRNet_W30_C(**args):
  593. model = HRNet(width=30, has_se=True, **args)
  594. return model
  595. def SE_HRNet_W32_C(**args):
  596. model = HRNet(width=32, has_se=True, **args)
  597. return model
  598. def SE_HRNet_W40_C(**args):
  599. model = HRNet(width=40, has_se=True, **args)
  600. return model
  601. def SE_HRNet_W44_C(**args):
  602. model = HRNet(width=44, has_se=True, **args)
  603. return model
  604. def SE_HRNet_W48_C(**args):
  605. model = HRNet(width=48, has_se=True, **args)
  606. return model
  607. def SE_HRNet_W60_C(**args):
  608. model = HRNet(width=60, has_se=True, **args)
  609. return model
  610. def SE_HRNet_W64_C(**args):
  611. model = HRNet(width=64, has_se=True, **args)
  612. return model