# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal

from paddlex.ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['BlazeNet']
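

# hard_swish(x) = x * relu6(x + 3) / 6: the piecewise-linear approximation of
# swish popularized by MobileNetV3. Kept as a module-level helper so that
# ConvBNLayer can select it via the string 'hard_swish'.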
def hard_swish(x):
    return x * F.relu6(x + 3) / 6.


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr,
                initializer=KaimingNormal(),
                name=name + "_weights"),
            bias_attr=False)

        param_attr = ParamAttr(name=name + "_bn_scale")
        bias_attr = ParamAttr(name=name + "_bn_offset")
        if norm_type == 'sync_bn':
            self._batch_norm = nn.SyncBatchNorm(
                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self._batch_norm = nn.BatchNorm(
                out_channels,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=False,
                moving_mean_name=name + '_bn_mean',
                moving_variance_name=name + '_bn_variance')

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x
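

# Usage sketch for ConvBNLayer (illustrative only; the input size and channel
# counts below are assumptions, not values taken from this file):
#     conv = ConvBNLayer(in_channels=3, out_channels=24, kernel_size=3,
#                        stride=2, padding=1, name="demo")
#     y = conv(paddle.randn([1, 3, 128, 128]))  # -> shape [1, 24, 64, 64]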


class BlazeBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels1,
                 out_channels2,
                 double_channels=None,
                 stride=1,
                 use_5x5kernel=True,
                 act='relu',
                 name=None):
        super(BlazeBlock, self).__init__()
        assert stride in [1, 2]
        self.use_pool = stride != 1
        self.use_double_block = double_channels is not None
        self.conv_dw = []
        if use_5x5kernel:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=5,
                        stride=stride,
                        padding=2,
                        num_groups=out_channels1,
                        name=name + "1_dw")))
        else:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_1",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_1")))
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_2",
                    ConvBNLayer(
                        in_channels=out_channels1,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=stride,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_2")))
        self.act = act if self.use_double_block else None
        self.conv_pw = ConvBNLayer(
            in_channels=out_channels1,
            out_channels=out_channels2,
            kernel_size=1,
            stride=1,
            padding=0,
            act=self.act,
            name=name + "1_sep")
        if self.use_double_block:
            self.conv_dw2 = []
            if use_5x5kernel:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                            num_groups=out_channels2,
                            name=name + "2_dw")))
            else:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_1",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_1")))
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_2",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_2")))
            self.conv_pw2 = ConvBNLayer(
                in_channels=out_channels2,
                out_channels=double_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                name=name + "2_sep")

        # shortcut: pool and project the input so it can be added to the
        # downsampled main branch
        if self.use_pool:
            shortcut_channel = double_channels or out_channels2
            self._shortcut = []
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_pool',
                    nn.MaxPool2D(
                        kernel_size=stride, stride=stride, ceil_mode=True)))
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_conv',
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=shortcut_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        name="shortcut" + name)))

    def forward(self, x):
        y = x
        # depthwise (plus optional second depthwise stage) then pointwise
        for conv_dw_block in self.conv_dw:
            y = conv_dw_block(y)
        y = self.conv_pw(y)
        if self.use_double_block:
            for conv_dw2_block in self.conv_dw2:
                y = conv_dw2_block(y)
            y = self.conv_pw2(y)
        # when the block downsamples, run the shortcut so shapes match
        if self.use_pool:
            for shortcut in self._shortcut:
                x = shortcut(x)
        return F.relu(paddle.add(x, y))
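

# Usage sketch for BlazeBlock (illustrative only; the values below are
# assumptions, not part of this file). A stride-2 double block halves the
# spatial size and projects the shortcut to `double_channels` so the residual
# add lines up:
#     block = BlazeBlock(48, 24, 96, double_channels=96, stride=2, name="demo")
#     y = block(paddle.randn([1, 48, 32, 32]))  # -> shape [1, 96, 16, 16]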


@register
@serializable
class BlazeNet(nn.Layer):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): number of filters for each blaze block.
        double_blaze_filters (list): number of filters for each double_blaze
            block.
        use_5x5kernel (bool): whether to use a 5x5 kernel in the depthwise
            convolutions.
    """

    def __init__(self,
                 blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48],
                                [48, 48]],
                 double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96],
                                       [96, 24, 96], [96, 24, 96, 2],
                                       [96, 24, 96], [96, 24, 96]],
                 use_5x5kernel=True,
                 act=None):
        super(BlazeNet, self).__init__()
        conv1_num_filters = blaze_filters[0][0]
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=conv1_num_filters,
            kernel_size=3,
            stride=2,
            padding=1,
            name="conv1")
        in_channels = conv1_num_filters
        self.blaze_block = []
        self._out_channels = []
        for k, v in enumerate(blaze_filters):
            assert len(v) in [2, 3], \
                "blaze_filters entry {} must have 2 or 3 values".format(v)
            if len(v) == 2:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            elif len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            stride=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            in_channels = v[1]

        for k, v in enumerate(double_blaze_filters):
            assert len(v) in [3, 4], \
                "double_blaze_filters entry {} must have 3 or 4 values".format(v)
            if len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            elif len(v) == 4:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            stride=v[3],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            in_channels = v[2]
            self._out_channels.append(in_channels)

    def forward(self, inputs):
        outs = []
        y = self.conv1(inputs['image'])
        for block in self.blaze_block:
            y = block(y)
            outs.append(y)
        # return the two pyramid levels consumed downstream
        return [outs[-4], outs[-1]]

    @property
    def out_shape(self):
        # channel specs for the same two levels that forward() returns
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[-4], self._out_channels[-1]]
        ]
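

if __name__ == '__main__':
    # Smoke-test sketch: build the default BlazeNet and run one forward pass.
    # The 128x128 input resolution is an illustrative assumption, not a value
    # taken from this file.
    model = BlazeNet()
    feats = model({'image': paddle.randn([1, 3, 128, 128])})
    print([f.shape for f in feats])  # two feature maps, 96 channels each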