bifpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec

__all__ = ['BiFPN']


class SeparableConvLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels=None,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(SeparableConvLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        self.in_channels = in_channels
        # Keep the channel count unchanged when out_channels is not given.
        self.out_channels = (out_channels
                             if out_channels is not None else in_channels)
        self.norm_type = norm_type
        self.norm_groups = norm_groups

        self.depthwise_conv = nn.Conv2D(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias_attr=False)
        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

        # norm type
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2D(self.out_channels)
        elif self.norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(self.out_channels)
        elif self.norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=self.norm_groups, num_channels=self.out_channels)

        # activation
        if act == 'swish':
            self.act = nn.Swish()
        elif act == 'relu':
            self.act = nn.ReLU()
        else:
            self.act = None

    def forward(self, x):
        # Pre-activation, then depthwise + pointwise conv, then optional norm.
        if self.act is not None:
            x = self.act(x)
        out = self.depthwise_conv(x)
        out = self.pointwise_conv(out)
        if self.norm_type is not None:
            out = self.norm(out)
        return out
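
# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hedged usage example of SeparableConvLayer on a dummy NCHW
# tensor. The batch size, channel count and spatial size are assumptions made
# for the demo; by default the layer preserves channels and spatial size.
def _demo_separable_conv_layer():
    x = paddle.randn([1, 64, 32, 32])
    layer = SeparableConvLayer(64, kernel_size=3, norm_type='bn', act='swish')
    # Swish pre-activation -> 3x3 depthwise conv -> 1x1 pointwise conv -> BN.
    y = layer(x)
    assert y.shape == [1, 64, 32, 32]
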

class BiFPNCell(nn.Layer):
    def __init__(self,
                 channels=256,
                 num_levels=5,
                 eps=1e-5,
                 use_weighted_fusion=True,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPNCell, self).__init__()
        self.channels = channels
        self.num_levels = num_levels
        self.eps = eps
        self.use_weighted_fusion = use_weighted_fusion

        # up (top-down pathway)
        self.conv_up = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])
        # down (bottom-up pathway)
        self.conv_down = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])

        if self.use_weighted_fusion:
            self.up_weights = self.create_parameter(
                shape=[self.num_levels - 1, 2],
                attr=ParamAttr(initializer=Constant(1.)))
            self.down_weights = self.create_parameter(
                shape=[self.num_levels - 1, 3],
                attr=ParamAttr(initializer=Constant(1.)))

    def _feature_fusion_cell(self,
                             conv_layer,
                             lateral_feat,
                             sampling_feat,
                             route_feat=None,
                             weights=None):
        if self.use_weighted_fusion:
            # Fast normalized fusion: w_i <- relu(w_i) / (sum_j relu(w_j) + eps)
            weights = F.relu(weights)
            weights = weights / (weights.sum() + self.eps)
            if route_feat is not None:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat + \
                           weights[2] * route_feat
            else:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat
        else:
            if route_feat is not None:
                out_feat = lateral_feat + sampling_feat + route_feat
            else:
                out_feat = lateral_feat + sampling_feat
        out_feat = conv_layer(out_feat)
        return out_feat

    def forward(self, feats):
        # feats: [P3 - P7]
        lateral_feats = []

        # up: top-down pathway, upsampling the coarser level at each step
        up_feature = feats[-1]
        for i, feature in enumerate(feats[::-1]):
            if i == 0:
                lateral_feats.append(feature)
            else:
                shape = paddle.shape(feature)
                up_feature = F.interpolate(
                    up_feature, size=[shape[2], shape[3]])
                lateral_feature = self._feature_fusion_cell(
                    self.conv_up[i - 1],
                    feature,
                    up_feature,
                    weights=self.up_weights[i - 1]
                    if self.use_weighted_fusion else None)
                lateral_feats.append(lateral_feature)
                up_feature = lateral_feature

        out_feats = []
        # down: bottom-up pathway, downsampling the finer level and adding the
        # original input as a route (skip) connection
        down_feature = lateral_feats[-1]
        for i, (lateral_feature,
                route_feature) in enumerate(zip(lateral_feats[::-1], feats)):
            if i == 0:
                out_feats.append(lateral_feature)
            else:
                down_feature = F.max_pool2d(down_feature, 3, 2, 1)
                if i == len(feats) - 1:
                    # The coarsest level fuses only two inputs (no extra route).
                    route_feature = None
                    weights = self.down_weights[
                        i - 1][:2] if self.use_weighted_fusion else None
                else:
                    weights = self.down_weights[
                        i - 1] if self.use_weighted_fusion else None
                out_feature = self._feature_fusion_cell(
                    self.conv_down[i - 1],
                    lateral_feature,
                    down_feature,
                    route_feature,
                    weights=weights)
                out_feats.append(out_feature)
                down_feature = out_feature

        return out_feats
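
# --- Illustrative sketch (not part of the original module) -------------------
# How a single BiFPNCell processes a five-level pyramid. The batch size,
# channel count and spatial sizes below are assumptions for the demo; the cell
# only requires `channels` channels per level, with each level half the
# spatial size of the previous one, and returns one fused map per input level.
def _demo_bifpn_cell():
    cell = BiFPNCell(channels=64, num_levels=5)
    # P3..P7-style inputs for an assumed 256x256 image (strides 8..128).
    feats = [
        paddle.randn([1, 64, 256 // s, 256 // s])
        for s in (8, 16, 32, 64, 128)
    ]
    outs = cell(feats)
    # The cell preserves per-level shapes; outputs stay ordered P3..P7.
    assert [o.shape for o in outs] == [f.shape for f in feats]
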

@register
@serializable
class BiFPN(nn.Layer):
    """
    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070

    Args:
        in_channels (list[int]): input channels of each level, which can be
            derived from the output shape of the backbone by from_config.
        out_channel (int): output channel of each level.
        num_extra_levels (int): the number of extra stages added after the
            last level. default: 2
        fpn_strides (list): the stride of each level.
        num_stacks (int): the number of stacked BiFPN cells, default: 1.
        use_weighted_fusion (bool): use weighted feature fusion in BiFPN,
            default: True.
        norm_type (string|None): the normalization type used in BiFPN. If
            norm_type is None, no normalization is applied after the convs;
            otherwise 'bn', 'sync_bn' and 'gn' are available. default: bn.
        norm_groups (int): the number of groups when norm_type is 'gn'.
        act (string|None): the activation function of BiFPN.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 out_channel=256,
                 num_extra_levels=2,
                 fpn_strides=[8, 16, 32, 64, 128],
                 num_stacks=1,
                 use_weighted_fusion=True,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPN, self).__init__()
        assert num_stacks > 0, "The number of stacks of BiFPN is at least 1."
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        assert num_extra_levels >= 0, \
            "The `num_extra_levels` must be non-negative (>= 0)."

        self.in_channels = in_channels
        self.out_channel = out_channel
        self.num_extra_levels = num_extra_levels
        self.num_stacks = num_stacks
        self.use_weighted_fusion = use_weighted_fusion
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.act = act
        self.num_levels = len(self.in_channels) + self.num_extra_levels
        if len(fpn_strides) != self.num_levels:
            for i in range(self.num_extra_levels):
                # Extend strides without mutating the default list in place.
                fpn_strides = fpn_strides + [fpn_strides[-1] * 2]
        self.fpn_strides = fpn_strides

        self.lateral_convs = nn.LayerList()
        for in_c in in_channels:
            self.lateral_convs.append(
                ConvNormLayer(in_c, self.out_channel, 1, 1))
        if self.num_extra_levels > 0:
            self.extra_convs = nn.LayerList()
            for i in range(self.num_extra_levels):
                if i == 0:
                    # The first extra level comes from a stride-2 conv on the
                    # last backbone feature; further levels use max pooling.
                    self.extra_convs.append(
                        ConvNormLayer(self.in_channels[-1], self.out_channel,
                                      3, 2))
                else:
                    self.extra_convs.append(nn.MaxPool2D(3, 2, 1))

        self.bifpn_cells = nn.LayerList()
        for i in range(self.num_stacks):
            self.bifpn_cells.append(
                BiFPNCell(
                    self.out_channel,
                    self.num_levels,
                    use_weighted_fusion=self.use_weighted_fusion,
                    norm_type=self.norm_type,
                    norm_groups=self.norm_groups,
                    act=self.act))

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'in_channels': [i.channels for i in input_shape],
            'fpn_strides': [i.stride for i in input_shape]
        }

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self.out_channel, stride=s) for s in self.fpn_strides
        ]

    def forward(self, feats):
        assert len(feats) == len(self.in_channels)
        fpn_feats = []
        for conv_layer, feature in zip(self.lateral_convs, feats):
            fpn_feats.append(conv_layer(feature))
        if self.num_extra_levels > 0:
            feat = feats[-1]
            for conv_layer in self.extra_convs:
                feat = conv_layer(feat)
                fpn_feats.append(feat)
        for bifpn_cell in self.bifpn_cells:
            fpn_feats = bifpn_cell(fpn_feats)
        return fpn_feats
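
# --- Illustrative sketch (not part of the original module) -------------------
# End-to-end usage of the BiFPN neck on backbone-like features. The channel
# counts and the 256x256 input size are assumptions made for the demo. With
# three input levels and the default num_extra_levels=2, the neck emits five
# levels, each with `out_channel` channels, at strides 8, 16, 32, 64 and 128.
def _demo_bifpn():
    neck = BiFPN(
        in_channels=[64, 128, 256], out_channel=64, fpn_strides=[8, 16, 32])
    # C3/C4/C5-style backbone outputs at strides 8, 16 and 32.
    feats = [
        paddle.randn([1, c, 256 // s, 256 // s])
        for c, s in zip([64, 128, 256], [8, 16, 32])
    ]
    outs = neck(feats)
    # 3 lateral levels + 2 extra levels = 5 outputs, each with 64 channels.
    assert len(outs) == 5 and all(o.shape[1] == 64 for o in outs)
    # out_shape reports the matching ShapeSpec list for downstream heads.
    assert [spec.stride for spec in neck.out_shape] == [8, 16, 32, 64, 128]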