# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import AdaptiveAvgPool2D, Linear
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Uniform
from numbers import Integral
import math

from paddlex.ppdet.core.workspace import register
from ..shape_spec import ShapeSpec

__all__ = ['HRNet']


class ConvNormLayer(nn.Layer):
    """Conv2D followed by batch/sync/group norm and an optional ReLU."""

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride=1,
                 norm_type='bn',
                 norm_groups=32,
                 use_dcn=False,
                 norm_decay=0.,
                 freeze_norm=False,
                 act=None,
                 name=None):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        self.act = act
        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=1,
            weight_attr=ParamAttr(
                name=name + "_weights",
                initializer=Normal(mean=0., std=0.01)),
            bias_attr=False)

        norm_lr = 0. if freeze_norm else 1.
        norm_name = name + '_bn'
        param_attr = ParamAttr(
            name=norm_name + "_scale",
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))
        bias_attr = ParamAttr(
            name=norm_name + "_offset",
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))
        global_stats = True if freeze_norm else False
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm(
                ch_out,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats,
                moving_mean_name=norm_name + '_mean',
                moving_variance_name=norm_name + '_variance')
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)
        if freeze_norm:
            for param in self.norm.parameters():
                param.stop_gradient = True

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        if self.act == 'relu':
            out = F.relu(out)
        return out


class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Layer1, self).__init__()
        self.bottleneck_block_list = []
        for i in range(4):
            bottleneck_block = self.add_sublayer(
                "block_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else 256,
                    num_filters=64,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, input):
        conv = input
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(TransitionLayer, self).__init__()
        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        ConvNormLayer(
                            ch_in=in_channels[i],
                            ch_out=out_channels[i],
                            filter_size=3,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            act='relu',
                            name=name + '_layer_' + str(i + 1)))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    ConvNormLayer(
                        ch_in=in_channels[-1],
                        ch_out=out_channels[i],
                        filter_size=3,
                        stride=2,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        act='relu',
                        name=name + '_layer_' + str(i + 1)))
            self.conv_bn_func_list.append(residual)

    def forward(self, input):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(input[idx])
            else:
                if idx < len(input):
                    outs.append(conv_bn_func(input[idx]))
                else:
                    outs.append(conv_bn_func(input[-1]))
        return outs


class Branches(nn.Layer):
    def __init__(self,
                 block_num,
                 in_channels,
                 out_channels,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Branches, self).__init__()
        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(block_num):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BottleneckBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample
        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=1,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            stride=stride,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv2")
        self.conv3 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters * 4,
            filter_size=1,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act=None,
            name=name + "_conv3")
        if self.downsample:
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                ch_out=num_filters * 4,
                filter_size=1,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")
        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        if self.downsample:
            residual = self.conv_down(input)
        if self.has_se:
            conv3 = self.se(conv3)
        y = paddle.add(x=residual, y=conv3)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BasicBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample
        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=stride,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=1,
            act=None,
            name=name + "_conv2")
        if self.downsample:
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                # Basic blocks have expansion 1, so the shortcut must match
                # conv2's num_filters output; the original `num_filters * 4`
                # would break the residual add whenever downsample is True.
                ch_out=num_filters,
                filter_size=1,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")
        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        if self.downsample:
            residual = self.conv_down(input)
        if self.has_se:
            conv2 = self.se(conv2)
        y = paddle.add(x=residual, y=conv2)
        y = F.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
            bias_attr=ParamAttr(name=name + '_sqz_offset'))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_filters,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
            bias_attr=ParamAttr(name=name + '_exc_offset'))

    def forward(self, input):
        pool = self.pool2d_gap(input)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = input * excitation
        return out


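# Shape walk-through of the squeeze-and-excitation path above (annotation
# added for clarity; it mirrors SELayer.forward step by step):
#   input       [N, C, H, W]
#   pool2d_gap  [N, C, 1, 1], squeezed to [N, C]
#   squeeze FC  [N, C // reduction_ratio], then ReLU
#   excitation  [N, C], then sigmoid, unsqueezed to [N, C, 1, 1]
#   output      input * excitation, a per-channel rescaling via broadcasting
# (num_filters equals num_channels at every SELayer call site in this file).

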
class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_filters,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 multi_scale_output=True,
                 name=None):
        super(Stage, self).__init__()
        self._num_modules = num_modules
        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1)))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name=name + '_' + str(i + 1)))
            self.stage_func_list.append(stage_func)

    def forward(self, input):
        out = input
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(HighResolutionModule, self).__init__()
        self.branches_func = Branches(
            block_num=4,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)
        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)

    def forward(self, input):
        out = self.branches_func(input)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(FuseLayers, self).__init__()
        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                residual_func = None
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        ConvNormLayer(
                            ch_in=in_channels[j],
                            ch_out=out_channels[i],
                            filter_size=1,
                            stride=1,
                            act=None,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            name=name + '_layer_' + str(i + 1) + '_' +
                            str(j + 1)))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[i],
                                    filter_size=3,
                                    stride=2,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act=None,
                                    name=name + '_layer_' + str(i + 1) + '_'
                                    + str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[j],
                                    filter_size=3,
                                    stride=2,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act="relu",
                                    name=name + '_layer_' + str(i + 1) + '_'
                                    + str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, input):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = input[i]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](input[j])
                    residual_func_idx += 1
                    y = F.interpolate(y, scale_factor=2**(j - i))
                    residual = paddle.add(x=residual, y=y)
                elif j < i:
                    y = input[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = paddle.add(x=residual, y=y)
            residual = F.relu(residual)
            outs.append(residual)
        return outs


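# Fusion pattern implemented by FuseLayers above (annotation for clarity):
# to fuse branch j into branch i,
#   j > i  (coarser branch): a 1x1 conv aligns channels, then the map is
#          upsampled by a factor of 2**(j - i);
#   j < i  (finer branch): a chain of i - j stride-2 3x3 convs downsamples
#          it, with ReLU on every conv except the last so the sum stays
#          pre-activation;
#   j == i: the branch passes through unchanged.
# Each fused sum is then passed through a final ReLU.

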
@register
class HRNet(nn.Layer):
    """
    HRNet, see https://arxiv.org/abs/1908.07919

    Args:
        width (int): base channel width of HRNet; one of 18, 30, 32, 40,
            44, 48, 60 or 64
        has_se (bool): whether to add an SE block in each residual block
        freeze_at (int): index of the stage output whose gradient is stopped
        freeze_norm (bool): whether to freeze the norm layers in HRNet
        norm_decay (float): weight decay for normalization layer weights
        return_idx (List): indices of the stage outputs to return
    """

    def __init__(self,
                 width=18,
                 has_se=False,
                 freeze_at=0,
                 freeze_norm=True,
                 norm_decay=0.,
                 return_idx=[0, 1, 2, 3]):
        super(HRNet, self).__init__()
        self.width = width
        self.has_se = has_se
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert len(return_idx) > 0, "need one or more return index"
        self.freeze_at = freeze_at
        self.return_idx = return_idx
        self.channels = {
            18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
            30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
            32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
            40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
            48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
            60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
            64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
        }

        channels_2, channels_3, channels_4 = self.channels[width]
        num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
        self._out_channels = channels_4
        self._out_strides = [4, 8, 16, 32]

        self.conv_layer1_1 = ConvNormLayer(
            ch_in=3,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_1")
        self.conv_layer1_2 = ConvNormLayer(
            ch_in=64,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_2")
        self.la1 = Layer1(
            num_channels=64,
            has_se=has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="layer2")
        self.tr1 = TransitionLayer(
            in_channels=[256],
            out_channels=channels_2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr1")
        self.st2 = Stage(
            num_channels=channels_2,
            num_modules=num_modules_2,
            num_filters=channels_2,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st2")
        self.tr2 = TransitionLayer(
            in_channels=channels_2,
            out_channels=channels_3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr2")
        self.st3 = Stage(
            num_channels=channels_3,
            num_modules=num_modules_3,
            num_filters=channels_3,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st3")
        self.tr3 = TransitionLayer(
            in_channels=channels_3,
            out_channels=channels_4,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr3")
        self.st4 = Stage(
            num_channels=channels_4,
            num_modules=num_modules_4,
            num_filters=channels_4,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            multi_scale_output=len(return_idx) > 1,
            name="st4")

    def forward(self, inputs):
        x = inputs['image']
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)
        la1 = self.la1(conv2)
        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)
        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)
        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)
        res = []
        for i, layer in enumerate(st4):
            if i == self.freeze_at:
                layer.stop_gradient = True
            if i in self.return_idx:
                res.append(layer)
        return res

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]


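# Minimal smoke-test sketch (added annotation, not part of the upstream file).
# It assumes a working PaddlePaddle install and that the module is executed in
# its package context (e.g. via `python -m ...`) so the relative import of
# ShapeSpec resolves. It runs a dummy 224x224 image through the backbone and
# prints the multi-scale output shapes.
if __name__ == '__main__':
    model = HRNet(width=18, return_idx=[0, 1, 2, 3])
    model.eval()
    feats = model({'image': paddle.randn([1, 3, 224, 224])})
    for spec, feat in zip(model.out_shape, feats):
        # For width=18, expect channels [18, 36, 72, 144] at strides
        # [4, 8, 16, 32] (56x56 down to 7x7 for a 224x224 input).
        print(spec.channels, spec.stride, feat.shape)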