centernet_fpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.layers import ConvNormLayer
from paddlex.ppdet.modeling.backbones.hardnet import ConvLayer, HarDBlock
from ..shape_spec import ShapeSpec

__all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN']
def fill_up_weights(up):
    """Initialize a depthwise Conv2DTranspose layer with bilinear
    upsampling weights: every output channel gets the same 2-D bilinear
    kernel, matching the `groups=ch_out` deconvolutions built in `IDAUp`.
    """
    weight = up.weight
    f = math.ceil(weight.shape[2] / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(weight.shape[2]):
        for j in range(weight.shape[3]):
            weight[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for c in range(1, weight.shape[0]):
        weight[c, 0, :, :] = weight[0, 0, :, :]
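
# A worked example (illustrative, not part of the original file): for the
# stride-2 deconvolutions built below, kernel_size = 4, so f = 2 and
# c = 0.75, and the 1-D profile (1 - |i / f - c|) over i = 0..3 is
# [0.25, 0.75, 0.75, 0.25]; the 2-D kernel is its outer product:
#
#     import numpy as np
#     k = np.array([1 - abs(i / 2 - 0.75) for i in range(4)])
#     kernel = np.outer(k, k)  # the classic bilinear upsampling kernel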
class IDAUp(nn.Layer):
    def __init__(self, ch_ins, ch_out, up_strides, dcn_v2=True):
        super(IDAUp, self).__init__()
        for i in range(1, len(ch_ins)):
            ch_in = ch_ins[i]
            up_s = int(up_strides[i])
            proj = nn.Sequential(
                ConvNormLayer(
                    ch_in,
                    ch_out,
                    filter_size=3,
                    stride=1,
                    use_dcn=dcn_v2,
                    bias_on=dcn_v2,
                    norm_decay=None,
                    dcn_lr_scale=1.,
                    dcn_regularizer=None),
                nn.ReLU())
            node = nn.Sequential(
                ConvNormLayer(
                    ch_out,
                    ch_out,
                    filter_size=3,
                    stride=1,
                    use_dcn=dcn_v2,
                    bias_on=dcn_v2,
                    norm_decay=None,
                    dcn_lr_scale=1.,
                    dcn_regularizer=None),
                nn.ReLU())

            param_attr = paddle.ParamAttr(initializer=KaimingUniform())
            up = nn.Conv2DTranspose(
                ch_out,
                ch_out,
                kernel_size=up_s * 2,
                weight_attr=param_attr,
                stride=up_s,
                padding=up_s // 2,
                groups=ch_out,
                bias_attr=False)
            # TODO: uncomment fill_up_weights
            #fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
            setattr(self, 'node_' + str(i), node)

    def forward(self, inputs, start_level, end_level):
        # Note: `inputs` is modified in place. Each level is projected to
        # `ch_out` channels, upsampled, and fused with the previous level.
        for i in range(start_level + 1, end_level):
            upsample = getattr(self, 'up_' + str(i - start_level))
            project = getattr(self, 'proj_' + str(i - start_level))
            inputs[i] = project(inputs[i])
            inputs[i] = upsample(inputs[i])
            node = getattr(self, 'node_' + str(i - start_level))
            inputs[i] = node(paddle.add(inputs[i], inputs[i - 1]))
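
# A minimal usage sketch (assumed shapes, not from the original file):
#
#     ida = IDAUp(ch_ins=[64, 128], ch_out=64, up_strides=[1, 2], dcn_v2=False)
#     feats = [paddle.rand([1, 64, 32, 32]), paddle.rand([1, 128, 16, 16])]
#     ida(feats, 0, 2)  # mutates `feats` in place
#     # feats[1] is now [1, 64, 32, 32]: projected to 64 channels,
#     # 2x upsampled, then fused with feats[0].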
class DLAUp(nn.Layer):
    def __init__(self, start_level, channels, scales, ch_in=None, dcn_v2=True):
        super(DLAUp, self).__init__()
        self.start_level = start_level
        if ch_in is None:
            ch_in = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(
                self,
                'ida_{}'.format(i),
                IDAUp(
                    ch_ins=ch_in[j:],
                    ch_out=channels[j],
                    up_strides=scales[j:] // scales[j],
                    dcn_v2=dcn_v2))
            scales[j + 1:] = scales[j]
            ch_in[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, inputs):
        out = [inputs[-1]]  # start with the deepest feature (stride 32)
        for i in range(len(inputs) - self.start_level - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(inputs, len(inputs) - i - 2, len(inputs))
            out.insert(0, inputs[-1])
        return out
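
# A shape sketch (assumed DLA-34-like values, not from the original file):
# with DLAUp(start_level=2, channels=[64, 128, 256, 512], scales=[1, 2, 4, 8]),
# each `ida_i` fuses one more level into the pyramid, and `forward` returns
# roughly [64ch@stride4, 128ch@stride8, 256ch@stride16, 512ch@stride32],
# finest first.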
@register
@serializable
class CenterNetDLAFPN(nn.Layer):
    """
    Args:
        in_channels (list): number of input feature channels from the backbone.
            [16, 32, 64, 128, 256, 512] by default, i.e. the channels of DLA-34
        down_ratio (int): the down ratio from images to heatmap, 4 by default
        last_level (int): the last level of input feature fed into the upsampling block
        out_channel (int): the channel of the output feature, 0 by default means
            the channel of the input feature whose down ratio is `down_ratio`
        first_level (int): the first level of input feature fed into the upsampling
            block, -1 by default and it will be calculated from down_ratio
        dcn_v2 (bool): whether to use DCNv2, True by default
    """

    def __init__(self,
                 in_channels,
                 down_ratio=4,
                 last_level=5,
                 out_channel=0,
                 first_level=-1,
                 dcn_v2=True):
        super(CenterNetDLAFPN, self).__init__()
        self.first_level = int(np.log2(
            down_ratio)) if first_level == -1 else first_level
        self.down_ratio = down_ratio
        self.last_level = last_level
        scales = [2**i for i in range(len(in_channels[self.first_level:]))]
        self.dla_up = DLAUp(
            self.first_level,
            in_channels[self.first_level:],
            scales,
            dcn_v2=dcn_v2)
        self.out_channel = out_channel
        if out_channel == 0:
            self.out_channel = in_channels[self.first_level]
        self.ida_up = IDAUp(
            in_channels[self.first_level:self.last_level],
            self.out_channel,
            [2**i for i in range(self.last_level - self.first_level)],
            dcn_v2=dcn_v2)

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape]}

    def forward(self, body_feats):
        dla_up_feats = self.dla_up(body_feats)

        ida_up_feats = []
        for i in range(self.last_level - self.first_level):
            ida_up_feats.append(dla_up_feats[i].clone())

        self.ida_up(ida_up_feats, 0, len(ida_up_feats))
        return ida_up_feats[-1]

    @property
    def out_shape(self):
        return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
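
# A usage sketch (assumed DLA-34 feature shapes at a 512x512 input,
# illustrative only):
#
#     fpn = CenterNetDLAFPN(in_channels=[16, 32, 64, 128, 256, 512],
#                           dcn_v2=False)
#     feats = [paddle.rand([1, c, 512 // 2**i, 512 // 2**i])
#              for i, c in enumerate([16, 32, 64, 128, 256, 512])]
#     out = fpn(feats)  # [1, 64, 128, 128]: 64 channels at stride down_ratio=4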
class TransitionUp(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()

    def forward(self, x, skip, concat=True):
        # Upsample `x` to the spatial size of `skip`, optionally
        # concatenating the skip feature along the channel axis.
        h, w = skip.shape[2], skip.shape[3]
        out = F.interpolate(
            x, size=(h, w), mode="bilinear", align_corners=True)
        if concat:
            out = paddle.concat([out, skip], 1)
        return out
@register
@serializable
class CenterNetHarDNetFPN(nn.Layer):
    """
    Args:
        in_channels (list): number of input feature channels from the backbone.
            [96, 214, 458, 784] by default, i.e. the channels of HarDNet-85
        num_layers (int): the number of HarDNet layers, 85 by default
        down_ratio (int): the down ratio from images to heatmap, 4 by default
        first_level (int): the first level of input feature fed into the
            upsampling block, -1 by default and it will be calculated from down_ratio
        last_level (int): the last level of input feature fed into the upsampling block
        out_channel (int): the channel of the output feature, 0 by default means
            the channel of the input feature whose down ratio is `down_ratio`
    """

    def __init__(self,
                 in_channels,
                 num_layers=85,
                 down_ratio=4,
                 first_level=-1,
                 last_level=4,
                 out_channel=0):
        super(CenterNetHarDNetFPN, self).__init__()
        self.first_level = int(np.log2(
            down_ratio)) - 1 if first_level == -1 else first_level
        self.down_ratio = down_ratio
        self.last_level = last_level
        self.last_pool = nn.AvgPool2D(kernel_size=2, stride=2)

        assert num_layers in [68, 85], \
            "HarDNet-{} is not supported.".format(num_layers)
        if num_layers == 85:
            self.last_proj = ConvLayer(784, 256, kernel_size=1)
            self.last_blk = HarDBlock(768, 80, 1.7, 8)
            self.skip_nodes = [1, 3, 8, 13]
            self.SC = [32, 32, 0]
            gr = [64, 48, 28]
            layers = [8, 8, 4]
            ch_list2 = [224 + self.SC[0], 160 + self.SC[1], 96 + self.SC[2]]
            channels = [96, 214, 458, 784]
            self.skip_lv = 3
        elif num_layers == 68:
            self.last_proj = ConvLayer(654, 192, kernel_size=1)
            self.last_blk = HarDBlock(576, 72, 1.7, 8)
            self.skip_nodes = [1, 3, 8, 11]
            self.SC = [32, 32, 0]
            gr = [48, 32, 20]
            layers = [8, 8, 4]
            ch_list2 = [224 + self.SC[0], 96 + self.SC[1], 64 + self.SC[2]]
            channels = [64, 124, 328, 654]
            self.skip_lv = 2

        self.transUpBlocks = nn.LayerList([])
        self.denseBlocksUp = nn.LayerList([])
        self.conv1x1_up = nn.LayerList([])
        self.avg9x9 = nn.AvgPool2D(
            kernel_size=(9, 9), stride=1, padding=(4, 4))
        prev_ch = self.last_blk.get_out_ch()
        for i in range(3):
            skip_ch = channels[3 - i]
            self.transUpBlocks.append(TransitionUp(prev_ch, prev_ch))
            if i < self.skip_lv:
                cur_ch = prev_ch + skip_ch
            else:
                cur_ch = prev_ch
            self.conv1x1_up.append(
                ConvLayer(
                    cur_ch, ch_list2[i], kernel_size=1))
            cur_ch = ch_list2[i]
            cur_ch -= self.SC[i]
            cur_ch *= 3

            blk = HarDBlock(cur_ch, gr[i], 1.7, layers[i])
            self.denseBlocksUp.append(blk)
            prev_ch = blk.get_out_ch()

        prev_ch += self.SC[0] + self.SC[1] + self.SC[2]
        self.out_channel = prev_ch

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape]}

    def forward(self, body_feats):
        x = body_feats[-1]
        x_sc = []
        x = self.last_proj(x)
        x = self.last_pool(x)
        # Enrich the feature with 9x9 local context (x2) and a spatially
        # normalized map (x3) before each HarDBlock.
        x2 = self.avg9x9(x)
        x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
        x = paddle.concat([x, x2, x3], 1)
        x = self.last_blk(x)

        for i in range(3):
            skip_x = body_feats[3 - i]
            x = self.transUpBlocks[i](x, skip_x, (i < self.skip_lv))
            x = self.conv1x1_up[i](x)
            if self.SC[i] > 0:
                # Split off SC[i] shortcut channels to be re-attached at
                # the decoder's final resolution.
                end = x.shape[1]
                x_sc.append(x[:, end - self.SC[i]:, :, :])
                x = x[:, :end - self.SC[i], :, :]
            x2 = self.avg9x9(x)
            x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
            x = paddle.concat([x, x2, x3], 1)
            x = self.denseBlocksUp[i](x)

        scs = [x]
        for i in range(3):
            if self.SC[i] > 0:
                scs.insert(
                    0,
                    F.interpolate(
                        x_sc[i],
                        size=(x.shape[2], x.shape[3]),
                        mode="bilinear",
                        align_corners=True))
        neck_feat = paddle.concat(scs, 1)
        return neck_feat

    @property
    def out_shape(self):
        return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
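
# A usage sketch (assumed HarDNet-85 feature shapes, illustrative only):
#
#     fpn = CenterNetHarDNetFPN(in_channels=[96, 214, 458, 784])
#     feats = [paddle.rand([1, c, 512 // s, 512 // s])
#              for c, s in zip([96, 214, 458, 784], [4, 8, 16, 32])]
#     out = fpn(feats)  # fpn.out_channel channels: decoder output plus the
#                       # upsampled shortcut channels split off in forward()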