# blazenet.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal

from paddlex.ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['BlazeNet']


def hard_swish(x):
    return x * F.relu6(x + 3) / 6.
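# hard_swish(x) = x * relu6(x + 3) / 6 is the piecewise-linear approximation
# of the swish activation popularized by MobileNetV3; for example,
# hard_swish(3.) == 3. and hard_swish(-3.) == 0.
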
class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr, initializer=KaimingNormal()),
            bias_attr=False)

        if norm_type == 'sync_bn':
            self._batch_norm = nn.SyncBatchNorm(out_channels)
        else:
            self._batch_norm = nn.BatchNorm(
                out_channels, act=None, use_global_stats=False)

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x

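# BlazeBlock is the building block of BlazeFace: a depthwise convolution
# (one 5x5 kernel, or two stacked 3x3 kernels when use_5x5kernel is False)
# followed by a 1x1 pointwise projection. When double_channels is set, a
# second depthwise/pointwise stage is appended (the "double" BlazeBlock),
# and a max-pool + 1x1-conv shortcut is built whenever stride == 2.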
class BlazeBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels1,
                 out_channels2,
                 double_channels=None,
                 stride=1,
                 use_5x5kernel=True,
                 act='relu',
                 name=None):
        super(BlazeBlock, self).__init__()
        assert stride in [1, 2]
        self.use_pool = stride != 1
        self.use_double_block = double_channels is not None
        self.conv_dw = []
        if use_5x5kernel:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=5,
                        stride=stride,
                        padding=2,
                        num_groups=out_channels1,
                        name=name + "1_dw")))
        else:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_1",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_1")))
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_2",
                    ConvBNLayer(
                        in_channels=out_channels1,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=stride,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_2")))
        self.act = act if self.use_double_block else None
        self.conv_pw = ConvBNLayer(
            in_channels=out_channels1,
            out_channels=out_channels2,
            kernel_size=1,
            stride=1,
            padding=0,
            act=self.act,
            name=name + "1_sep")
        if self.use_double_block:
            self.conv_dw2 = []
            if use_5x5kernel:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                            num_groups=out_channels2,
                            name=name + "2_dw")))
            else:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_1",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_1")))
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_2",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_2")))
            self.conv_pw2 = ConvBNLayer(
                in_channels=out_channels2,
                out_channels=double_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                name=name + "2_sep")
        # shortcut
        if self.use_pool:
            shortcut_channel = double_channels or out_channels2
            self._shortcut = []
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_pool',
                    nn.MaxPool2D(
                        kernel_size=stride, stride=stride, ceil_mode=True)))
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_conv',
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=shortcut_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        name="shortcut" + name)))

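    # forward: run the depthwise stack, project with the pointwise conv, and
    # (for double blocks) repeat with the second stage; when stride == 2 the
    # input is max-pooled and projected so it can be added residually.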
    def forward(self, x):
        y = x
        for conv_dw_block in self.conv_dw:
            y = conv_dw_block(y)
        y = self.conv_pw(y)
        if self.use_double_block:
            for conv_dw2_block in self.conv_dw2:
                y = conv_dw2_block(y)
            y = self.conv_pw2(y)
        if self.use_pool:
            for shortcut in self._shortcut:
                x = shortcut(x)
        return F.relu(paddle.add(x, y))

@register
@serializable
class BlazeNet(nn.Layer):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): number of filters for each BlazeBlock.
        double_blaze_filters (list): number of filters for each double BlazeBlock.
        use_5x5kernel (bool): whether the depth-wise conv uses a 5x5 kernel.
    """

    def __init__(self,
                 blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48],
                                [48, 48]],
                 double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96],
                                       [96, 24, 96], [96, 24, 96, 2],
                                       [96, 24, 96], [96, 24, 96]],
                 use_5x5kernel=True,
                 act=None):
        super(BlazeNet, self).__init__()
        conv1_num_filters = blaze_filters[0][0]
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=conv1_num_filters,
            kernel_size=3,
            stride=2,
            padding=1,
            name="conv1")
        in_channels = conv1_num_filters
        self.blaze_block = []
        self._out_channels = []
        for k, v in enumerate(blaze_filters):
            assert len(v) in [2, 3], \
                "blaze_filters {} not in [2, 3]".format(v)
            if len(v) == 2:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            elif len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            stride=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            in_channels = v[1]

        for k, v in enumerate(double_blaze_filters):
            assert len(v) in [3, 4], \
                "double_blaze_filters {} not in [3, 4]".format(v)
            if len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            elif len(v) == 4:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            stride=v[3],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            in_channels = v[2]
            self._out_channels.append(in_channels)

    def forward(self, inputs):
        outs = []
        y = self.conv1(inputs['image'])
        for block in self.blaze_block:
            y = block(y)
            outs.append(y)
        return [outs[-4], outs[-1]]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[-4], self._out_channels[-1]]
        ]
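
# --- Usage sketch (not part of the original file) ---
# A minimal smoke test, assuming paddle and paddlex are installed. Because of
# the relative `..shape_spec` import above, this module must be imported as
# part of its package rather than run directly; the module path below is
# inferred from that import and may differ. The input is a dict with an
# 'image' key, matching forward() above; the 1x3x640x640 shape is an
# illustrative assumption, not a requirement of the model.
#
#   from paddlex.ppdet.modeling.backbones.blazenet import BlazeNet
#
#   model = BlazeNet()
#   feats = model({'image': paddle.randn([1, 3, 640, 640])})
#   print([f.shape for f in feats])               # two pyramid levels
#   print([s.channels for s in model.out_shape])  # [96, 96] with the defaults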