hrnet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. from paddle import fluid
  19. from paddle.fluid.param_attr import ParamAttr
  20. from paddle.fluid.framework import Variable
  21. from paddle.fluid.regularizer import L2Decay
  22. from numbers import Integral
  23. from paddle.fluid.initializer import MSRA
  24. import math
  25. __all__ = ['HRNet']
  26. class HRNet(object):
  27. def __init__(self,
  28. width=40,
  29. has_se=False,
  30. freeze_at=0,
  31. norm_type='bn',
  32. freeze_norm=False,
  33. norm_decay=0.,
  34. feature_maps=[2, 3, 4, 5],
  35. num_classes=None):
  36. super(HRNet, self).__init__()
  37. if isinstance(feature_maps, Integral):
  38. feature_maps = [feature_maps]
  39. assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
  40. assert len(feature_maps) > 0, "need one or more feature maps"
  41. assert norm_type in ['bn', 'sync_bn']
  42. self.width = width
  43. self.has_se = has_se
  44. self.num_modules = {
  45. '18_small_v1': [1, 1, 1, 1],
  46. '18': [1, 1, 4, 3],
  47. '30': [1, 1, 4, 3],
  48. '32': [1, 1, 4, 3],
  49. '40': [1, 1, 4, 3],
  50. '44': [1, 1, 4, 3],
  51. '48': [1, 1, 4, 3],
  52. '60': [1, 1, 4, 3],
  53. '64': [1, 1, 4, 3]
  54. }
  55. self.num_blocks = {
  56. '18_small_v1': [[1], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
  57. '18': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  58. '30': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  59. '32': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  60. '40': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  61. '44': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  62. '48': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  63. '60': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
  64. '64': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]]
  65. }
  66. self.channels = {
  67. '18_small_v1': [[32], [16, 32], [16, 32, 64], [16, 32, 64, 128]],
  68. '18': [[64], [18, 36], [18, 36, 72], [18, 36, 72, 144]],
  69. '30': [[64], [30, 60], [30, 60, 120], [30, 60, 120, 240]],
  70. '32': [[64], [32, 64], [32, 64, 128], [32, 64, 128, 256]],
  71. '40': [[64], [40, 80], [40, 80, 160], [40, 80, 160, 320]],
  72. '44': [[64], [44, 88], [44, 88, 176], [44, 88, 176, 352]],
  73. '48': [[64], [48, 96], [48, 96, 192], [48, 96, 192, 384]],
  74. '60': [[64], [60, 120], [60, 120, 240], [60, 120, 240, 480]],
  75. '64': [[64], [64, 128], [64, 128, 256], [64, 128, 256, 512]],
  76. }
  77. self.freeze_at = freeze_at
  78. self.norm_type = norm_type
  79. self.norm_decay = norm_decay
  80. self.freeze_norm = freeze_norm
  81. self.feature_maps = feature_maps
  82. self.num_classes = num_classes
  83. self.end_points = []
  84. return
  85. def net(self, input):
  86. width = self.width
  87. channels_1, channels_2, channels_3, channels_4 = self.channels[str(
  88. width)]
  89. num_modules_1, num_modules_2, num_modules_3, num_modules_4 = self.num_modules[
  90. str(width)]
  91. num_blocks_1, num_blocks_2, num_blocks_3, num_blocks_4 = self.num_blocks[
  92. str(width)]
  93. x = self.conv_bn_layer(
  94. input=input,
  95. filter_size=3,
  96. num_filters=channels_1[0],
  97. stride=2,
  98. if_act=True,
  99. name='layer1_1')
  100. x = self.conv_bn_layer(
  101. input=x,
  102. filter_size=3,
  103. num_filters=channels_1[0],
  104. stride=2,
  105. if_act=True,
  106. name='layer1_2')
  107. la1 = self.layer1(x, num_blocks_1, channels_1, name='layer2')
  108. tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
  109. st2 = self.stage(
  110. tr1, num_modules_2, num_blocks_2, channels_2, name='st2')
  111. tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
  112. st3 = self.stage(
  113. tr2, num_modules_3, num_blocks_3, channels_3, name='st3')
  114. tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
  115. st4 = self.stage(
  116. tr3, num_modules_4, num_blocks_4, channels_4, name='st4')
  117. # classification
  118. if self.num_classes:
  119. last_cls = self.last_cls_out(x=st4, name='cls_head')
  120. y = last_cls[0]
  121. last_num_filters = [256, 512, 1024]
  122. for i in range(3):
  123. y = fluid.layers.elementwise_add(
  124. last_cls[i + 1],
  125. self.conv_bn_layer(
  126. input=y,
  127. filter_size=3,
  128. num_filters=last_num_filters[i],
  129. stride=2,
  130. name='cls_head_add' + str(i + 1)))
  131. y = self.conv_bn_layer(
  132. input=y,
  133. filter_size=1,
  134. num_filters=2048,
  135. stride=1,
  136. name='cls_head_last_conv')
  137. pool = fluid.layers.pool2d(
  138. input=y, pool_type='avg', global_pooling=True)
  139. stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
  140. out = fluid.layers.fc(
  141. input=pool,
  142. size=self.num_classes,
  143. param_attr=ParamAttr(
  144. name='fc_weights',
  145. initializer=fluid.initializer.Uniform(-stdv, stdv)),
  146. bias_attr=ParamAttr(name='fc_offset'))
  147. return out
  148. # segmentation
  149. if self.feature_maps == "stage4":
  150. return st4
  151. self.end_points = st4
  152. return st4[-1]
  153. def layer1(self, input, num_blocks, channels, name=None):
  154. conv = input
  155. for i in range(num_blocks[0]):
  156. conv = self.bottleneck_block(
  157. conv,
  158. num_filters=channels[0],
  159. downsample=True if i == 0 else False,
  160. name=name + '_' + str(i + 1))
  161. return conv
  162. def transition_layer(self, x, in_channels, out_channels, name=None):
  163. num_in = len(in_channels)
  164. num_out = len(out_channels)
  165. out = []
  166. for i in range(num_out):
  167. if i < num_in:
  168. if in_channels[i] != out_channels[i]:
  169. residual = self.conv_bn_layer(
  170. x[i],
  171. filter_size=3,
  172. num_filters=out_channels[i],
  173. name=name + '_layer_' + str(i + 1))
  174. out.append(residual)
  175. else:
  176. out.append(x[i])
  177. else:
  178. residual = self.conv_bn_layer(
  179. x[-1],
  180. filter_size=3,
  181. num_filters=out_channels[i],
  182. stride=2,
  183. name=name + '_layer_' + str(i + 1))
  184. out.append(residual)
  185. return out
  186. def branches(self, x, block_num, channels, name=None):
  187. out = []
  188. for i in range(len(channels)):
  189. residual = x[i]
  190. for j in range(block_num[i]):
  191. residual = self.basic_block(
  192. residual,
  193. channels[i],
  194. name=name + '_branch_layer_' + str(i + 1) + '_' +
  195. str(j + 1))
  196. out.append(residual)
  197. return out
  198. def fuse_layers(self, x, channels, multi_scale_output=True, name=None):
  199. out = []
  200. for i in range(len(channels) if multi_scale_output else 1):
  201. residual = x[i]
  202. if self.feature_maps == "stage4":
  203. shape = fluid.layers.shape(residual)
  204. width = shape[-1]
  205. height = shape[-2]
  206. for j in range(len(channels)):
  207. if j > i:
  208. y = self.conv_bn_layer(
  209. x[j],
  210. filter_size=1,
  211. num_filters=channels[i],
  212. if_act=False,
  213. name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
  214. if self.feature_maps == "stage4":
  215. y = fluid.layers.resize_bilinear(
  216. input=y,
  217. out_shape=[height, width],
  218. align_corners=False,
  219. align_mode=1)
  220. else:
  221. y = fluid.layers.resize_nearest(
  222. input=y, scale=2**(j - i), align_corners=False)
  223. residual = fluid.layers.elementwise_add(
  224. x=residual, y=y, act=None)
  225. elif j < i:
  226. y = x[j]
  227. for k in range(i - j):
  228. if k == i - j - 1:
  229. y = self.conv_bn_layer(
  230. y,
  231. filter_size=3,
  232. num_filters=channels[i],
  233. stride=2,
  234. if_act=False,
  235. name=name + '_layer_' + str(i + 1) + '_' +
  236. str(j + 1) + '_' + str(k + 1))
  237. else:
  238. y = self.conv_bn_layer(
  239. y,
  240. filter_size=3,
  241. num_filters=channels[j],
  242. stride=2,
  243. name=name + '_layer_' + str(i + 1) + '_' +
  244. str(j + 1) + '_' + str(k + 1))
  245. residual = fluid.layers.elementwise_add(
  246. x=residual, y=y, act=None)
  247. residual = fluid.layers.relu(residual)
  248. out.append(residual)
  249. return out
  250. def high_resolution_module(self,
  251. x,
  252. num_blocks,
  253. channels,
  254. multi_scale_output=True,
  255. name=None):
  256. residual = self.branches(x, num_blocks, channels, name=name)
  257. out = self.fuse_layers(
  258. residual,
  259. channels,
  260. multi_scale_output=multi_scale_output,
  261. name=name)
  262. return out
  263. def stage(self,
  264. x,
  265. num_modules,
  266. num_blocks,
  267. channels,
  268. multi_scale_output=True,
  269. name=None):
  270. out = x
  271. for i in range(num_modules):
  272. if i == num_modules - 1 and multi_scale_output == False:
  273. out = self.high_resolution_module(
  274. out,
  275. num_blocks,
  276. channels,
  277. multi_scale_output=False,
  278. name=name + '_' + str(i + 1))
  279. else:
  280. out = self.high_resolution_module(
  281. out, num_blocks, channels, name=name + '_' + str(i + 1))
  282. return out
  283. def last_cls_out(self, x, name=None):
  284. out = []
  285. num_filters_list = [32, 64, 128, 256]
  286. for i in range(len(x)):
  287. out.append(
  288. self.bottleneck_block(
  289. input=x[i],
  290. num_filters=num_filters_list[i],
  291. name=name + 'conv_' + str(i + 1),
  292. downsample=True))
  293. return out
  294. def basic_block(self,
  295. input,
  296. num_filters,
  297. stride=1,
  298. downsample=False,
  299. name=None):
  300. residual = input
  301. conv = self.conv_bn_layer(
  302. input=input,
  303. filter_size=3,
  304. num_filters=num_filters,
  305. stride=stride,
  306. name=name + '_conv1')
  307. conv = self.conv_bn_layer(
  308. input=conv,
  309. filter_size=3,
  310. num_filters=num_filters,
  311. if_act=False,
  312. name=name + '_conv2')
  313. if downsample:
  314. residual = self.conv_bn_layer(
  315. input=input,
  316. filter_size=1,
  317. num_filters=num_filters,
  318. if_act=False,
  319. name=name + '_downsample')
  320. if self.has_se:
  321. conv = self.squeeze_excitation(
  322. input=conv,
  323. num_channels=num_filters,
  324. reduction_ratio=16,
  325. name=name + '_fc')
  326. return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
  327. def bottleneck_block(self,
  328. input,
  329. num_filters,
  330. stride=1,
  331. downsample=False,
  332. name=None):
  333. residual = input
  334. conv = self.conv_bn_layer(
  335. input=input,
  336. filter_size=1,
  337. num_filters=num_filters,
  338. name=name + '_conv1')
  339. conv = self.conv_bn_layer(
  340. input=conv,
  341. filter_size=3,
  342. num_filters=num_filters,
  343. stride=stride,
  344. name=name + '_conv2')
  345. conv = self.conv_bn_layer(
  346. input=conv,
  347. filter_size=1,
  348. num_filters=num_filters * 4,
  349. if_act=False,
  350. name=name + '_conv3')
  351. if downsample:
  352. residual = self.conv_bn_layer(
  353. input=input,
  354. filter_size=1,
  355. num_filters=num_filters * 4,
  356. if_act=False,
  357. name=name + '_downsample')
  358. if self.has_se:
  359. conv = self.squeeze_excitation(
  360. input=conv,
  361. num_channels=num_filters * 4,
  362. reduction_ratio=16,
  363. name=name + '_fc')
  364. return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
  365. def squeeze_excitation(self,
  366. input,
  367. num_channels,
  368. reduction_ratio,
  369. name=None):
  370. pool = fluid.layers.pool2d(
  371. input=input, pool_size=0, pool_type='avg', global_pooling=True)
  372. stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
  373. squeeze = fluid.layers.fc(
  374. input=pool,
  375. size=num_channels / reduction_ratio,
  376. act='relu',
  377. param_attr=fluid.param_attr.ParamAttr(
  378. initializer=fluid.initializer.Uniform(-stdv, stdv),
  379. name=name + '_sqz_weights'),
  380. bias_attr=ParamAttr(name=name + '_sqz_offset'))
  381. stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
  382. excitation = fluid.layers.fc(
  383. input=squeeze,
  384. size=num_channels,
  385. act='sigmoid',
  386. param_attr=fluid.param_attr.ParamAttr(
  387. initializer=fluid.initializer.Uniform(-stdv, stdv),
  388. name=name + '_exc_weights'),
  389. bias_attr=ParamAttr(name=name + '_exc_offset'))
  390. scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
  391. return scale
  392. def conv_bn_layer(self,
  393. input,
  394. filter_size,
  395. num_filters,
  396. stride=1,
  397. padding=1,
  398. num_groups=1,
  399. if_act=True,
  400. name=None):
  401. conv = fluid.layers.conv2d(
  402. input=input,
  403. num_filters=num_filters,
  404. filter_size=filter_size,
  405. stride=stride,
  406. padding=(filter_size - 1) // 2,
  407. groups=num_groups,
  408. act=None,
  409. param_attr=ParamAttr(
  410. initializer=MSRA(), name=name + '_weights'),
  411. bias_attr=False)
  412. bn_name = name + '_bn'
  413. bn = self._bn(input=conv, bn_name=bn_name)
  414. if if_act:
  415. bn = fluid.layers.relu(bn)
  416. return bn
  417. def _bn(self, input, act=None, bn_name=None):
  418. norm_lr = 0. if self.freeze_norm else 1.
  419. norm_decay = self.norm_decay
  420. if self.num_classes or self.feature_maps == "stage4":
  421. regularizer = None
  422. pattr_initializer = fluid.initializer.Constant(1.0)
  423. battr_initializer = fluid.initializer.Constant(0.0)
  424. else:
  425. regularizer = L2Decay(norm_decay)
  426. pattr_initializer = None
  427. battr_initializer = None
  428. pattr = ParamAttr(
  429. name=bn_name + '_scale',
  430. learning_rate=norm_lr,
  431. regularizer=regularizer,
  432. initializer=pattr_initializer)
  433. battr = ParamAttr(
  434. name=bn_name + '_offset',
  435. learning_rate=norm_lr,
  436. regularizer=regularizer,
  437. initializer=battr_initializer)
  438. global_stats = True if self.freeze_norm else False
  439. out = fluid.layers.batch_norm(
  440. input=input,
  441. act=act,
  442. name=bn_name + '.output.1',
  443. param_attr=pattr,
  444. bias_attr=battr,
  445. moving_mean_name=bn_name + '_mean',
  446. moving_variance_name=bn_name + '_variance',
  447. use_global_stats=global_stats)
  448. scale = fluid.framework._get_var(pattr.name)
  449. bias = fluid.framework._get_var(battr.name)
  450. if self.freeze_norm:
  451. scale.stop_gradient = True
  452. bias.stop_gradient = True
  453. return out
  454. def __call__(self, input):
  455. assert isinstance(input, Variable)
  456. if isinstance(self.feature_maps, (list, tuple)):
  457. assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
  458. "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
  459. res_endpoints = []
  460. res = input
  461. feature_maps = self.feature_maps
  462. out = self.net(input)
  463. if self.num_classes or self.feature_maps == "stage4":
  464. return out
  465. for i in feature_maps:
  466. res = self.end_points[i - 2]
  467. if i in self.feature_maps:
  468. res_endpoints.append(res)
  469. if self.freeze_at >= i:
  470. res.stop_gradient = True
  471. return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
  472. for idx, feat in enumerate(res_endpoints)])