blazeface_fpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn.initializer import KaimingNormal
from paddlex.ppdet.core.workspace import register, serializable

from ..shape_spec import ShapeSpec

__all__ = ['BlazeNeck']


def hard_swish(x):
    # Hard-Swish activation: x * ReLU6(x + 3) / 6.
    return x * F.relu6(x + 3) / 6.
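
# Quick numeric check (editor's illustration, not part of the original file):
# hard_swish(x) is 0 for x <= -3, equals x for x >= 3, and interpolates in
# between, e.g.
#   hard_swish(paddle.to_tensor([-4., 0., 1., 4.]))  # -> [0., 0., 0.6667, 4.]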


class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm + optional activation, the basic block of this neck."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr,
                initializer=KaimingNormal(),
                name=name + "_weights"),
            bias_attr=False)

        param_attr = ParamAttr(name=name + "_bn_scale")
        bias_attr = ParamAttr(name=name + "_bn_offset")
        if norm_type == 'sync_bn':
            self._batch_norm = nn.SyncBatchNorm(
                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self._batch_norm = nn.BatchNorm(
                out_channels,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=False,
                moving_mean_name=name + '_bn_mean',
                moving_variance_name=name + '_bn_variance')

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x
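
# Illustrative usage (editor's sketch; shapes are assumptions for the demo):
#   conv = ConvBNLayer(in_channels=24, out_channels=48, kernel_size=3,
#                      stride=1, padding=1, act='leaky', name='demo_conv')
#   y = conv(paddle.rand([1, 24, 32, 32]))  # y.shape == [1, 48, 32, 32]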


class FPN(nn.Layer):
    """Two-level feature pyramid: 1x1 laterals project both levels to
    out_channels // 2, the coarser level is upsampled and added to the finer
    one, and a 3x3 convolution smooths the merged map."""

    def __init__(self, in_channels, out_channels, name=None):
        super(FPN, self).__init__()
        self.conv1_fpn = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=1,
            padding=0,
            stride=1,
            act='leaky',
            name=name + '_output1')
        self.conv2_fpn = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=1,
            padding=0,
            stride=1,
            act='leaky',
            name=name + '_output2')
        self.conv3_fpn = ConvBNLayer(
            out_channels // 2,
            out_channels // 2,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + '_merge')

    def forward(self, input):
        output1 = self.conv1_fpn(input[0])
        output2 = self.conv2_fpn(input[1])
        # Upsample the coarser level to the finer level's spatial size,
        # fuse by element-wise addition, then smooth with the merge conv.
        up2 = F.upsample(
            output2, size=paddle.shape(output1)[-2:], mode='nearest')
        output1 = paddle.add(output1, up2)
        output1 = self.conv3_fpn(output1)
        return output1, output2
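
# Illustrative usage (editor's sketch; 96-channel inputs are an assumption
# matching the usual BlazeFace backbone outputs):
#   fpn = FPN(in_channels=96, out_channels=96, name='demo_fpn')
#   c_fine = paddle.rand([1, 96, 16, 16])    # higher-resolution feature
#   c_coarse = paddle.rand([1, 96, 8, 8])    # lower-resolution feature
#   p_fine, p_coarse = fpn([c_fine, c_coarse])
#   # p_fine.shape == [1, 48, 16, 16]; p_coarse.shape == [1, 48, 8, 8]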


class SSH(nn.Layer):
    """SSH-style context module: branch outputs taken after one, three and
    five stacked 3x3 convolutions are concatenated along the channel axis."""

    def __init__(self, in_channels, out_channels, name=None):
        super(SSH, self).__init__()
        assert out_channels % 4 == 0
        self.conv0_ssh = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv3')
        self.conv1_ssh = ConvBNLayer(
            out_channels // 2,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + 'ssh_conv5_1')
        self.conv2_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv5_2')
        self.conv3_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + 'ssh_conv7_1')
        self.conv4_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv7_2')

    def forward(self, x):
        conv0 = self.conv0_ssh(x)
        conv1 = self.conv1_ssh(conv0)
        conv2 = self.conv2_ssh(conv1)
        conv3 = self.conv3_ssh(conv2)
        conv4 = self.conv4_ssh(conv3)
        # Concatenated channels: out/2 + out/4 + out/4 == out_channels.
        concat = paddle.concat([conv0, conv2, conv4], axis=1)
        return F.relu(concat)
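
# Illustrative usage (editor's sketch; shapes are assumptions for the demo).
# SSH preserves both the spatial size and the channel count:
#   ssh = SSH(in_channels=48, out_channels=48, name='demo_ssh')
#   y = ssh(paddle.rand([1, 48, 16, 16]))  # y.shape == [1, 48, 16, 16]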


@register
@serializable
class BlazeNeck(nn.Layer):
    """Neck for BlazeFace. Depending on `neck_type` ('None', 'only_fpn',
    'only_ssh' or 'fpn_ssh'), the two backbone features are passed through
    unchanged, fused by an FPN, refined by SSH modules, or both."""

    def __init__(self, in_channel, neck_type="None", data_format='NCHW'):
        super(BlazeNeck, self).__init__()
        self.neck_type = neck_type
        self.return_input = False
        self._out_channels = in_channel
        if self.neck_type == 'None':
            self.return_input = True
        if "fpn" in self.neck_type:
            self.fpn = FPN(self._out_channels[0],
                           self._out_channels[1],
                           name='fpn')
            # The FPN laterals halve the channel count of each level.
            self._out_channels = [
                self._out_channels[0] // 2, self._out_channels[1] // 2
            ]
        if "ssh" in self.neck_type:
            self.ssh1 = SSH(self._out_channels[0],
                            self._out_channels[0],
                            name='ssh1')
            self.ssh2 = SSH(self._out_channels[1],
                            self._out_channels[1],
                            name='ssh2')
            # SSH keeps the channel count of each level unchanged.
            self._out_channels = [self._out_channels[0], self._out_channels[1]]

    def forward(self, inputs):
        if self.return_input:
            return inputs
        output1, output2 = None, None
        if "fpn" in self.neck_type:
            backout_4, backout_1 = inputs
            output1, output2 = self.fpn([backout_4, backout_1])
        if self.neck_type == "only_fpn":
            return [output1, output2]
        if self.neck_type == "only_ssh":
            output1, output2 = inputs
        # For 'fpn_ssh' the SSH modules run on the FPN outputs; for
        # 'only_ssh' they run directly on the backbone features.
        feature1 = self.ssh1(output1)
        feature2 = self.ssh2(output2)
        return [feature1, feature2]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[0], self._out_channels[1]]
        ]
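

# Illustrative end-to-end usage (editor's sketch, not part of the original
# file). in_channel=[96, 96] is an assumption matching the usual BlazeFace
# backbone outputs; with neck_type='fpn_ssh' the two features are fused by
# the FPN and then refined by one SSH module per level.
if __name__ == '__main__':
    neck = BlazeNeck(in_channel=[96, 96], neck_type='fpn_ssh')
    feats = [paddle.rand([1, 96, 16, 16]), paddle.rand([1, 96, 8, 8])]
    out1, out2 = neck(feats)
    print(out1.shape, out2.shape)  # [1, 48, 16, 16] [1, 48, 8, 8]
    print([s.channels for s in neck.out_shape])  # [48, 48]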