blazenet.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal

from paddlex.ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['BlazeNet']


def hard_swish(x):
    return x * F.relu6(x + 3) / 6.
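
# hard_swish is the piecewise gate from MobileNetV3: 0 for x <= -3,
# identity for x >= 3, and x * (x + 3) / 6 in between, e.g.
# hard_swish(-3) == 0, hard_swish(0) == 0, hard_swish(3) == 3.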


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr, initializer=KaimingNormal()),
            bias_attr=False)

        if norm_type in ['bn', 'sync_bn']:
            self._batch_norm = nn.BatchNorm2D(out_channels)

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x
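
# A minimal usage sketch (illustrative values, not from the detector config):
# with num_groups equal to in_channels this is a depthwise conv that keeps
# the channel count, which is exactly how BlazeBlock below uses the layer.
#
#   dw = ConvBNLayer(in_channels=8, out_channels=8, kernel_size=3,
#                    stride=1, padding=1, num_groups=8, act='relu')
#   y = dw(paddle.rand([1, 8, 32, 32]))  # -> shape [1, 8, 32, 32]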


class BlazeBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels1,
                 out_channels2,
                 double_channels=None,
                 stride=1,
                 use_5x5kernel=True,
                 act='relu',
                 name=None):
        super(BlazeBlock, self).__init__()
        assert stride in [1, 2]
        self.use_pool = stride != 1
        self.use_double_block = double_channels is not None
        self.conv_dw = []
        if use_5x5kernel:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=5,
                        stride=stride,
                        padding=2,
                        num_groups=out_channels1,
                        name=name + "1_dw")))
        else:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_1",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_1")))
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_2",
                    ConvBNLayer(
                        in_channels=out_channels1,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=stride,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_2")))
        self.act = act if self.use_double_block else None
        self.conv_pw = ConvBNLayer(
            in_channels=out_channels1,
            out_channels=out_channels2,
            kernel_size=1,
            stride=1,
            padding=0,
            act=self.act,
            name=name + "1_sep")
        if self.use_double_block:
            self.conv_dw2 = []
            if use_5x5kernel:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                            num_groups=out_channels2,
                            name=name + "2_dw")))
            else:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_1",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_1")))
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_2",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_2")))
            self.conv_pw2 = ConvBNLayer(
                in_channels=out_channels2,
                out_channels=double_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                name=name + "2_sep")
        # shortcut
        if self.use_pool:
            shortcut_channel = double_channels or out_channels2
            self._shortcut = []
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_pool',
                    nn.MaxPool2D(
                        kernel_size=stride, stride=stride, ceil_mode=True)))
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_conv',
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=shortcut_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        name="shortcut" + name)))

    def forward(self, x):
        y = x
        for conv_dw_block in self.conv_dw:
            y = conv_dw_block(y)
        y = self.conv_pw(y)
        if self.use_double_block:
            for conv_dw2_block in self.conv_dw2:
                y = conv_dw2_block(y)
            y = self.conv_pw2(y)
        # on stride-2 blocks the shortcut downsamples and matches channels
        # before the residual add
        if self.use_pool:
            for shortcut in self._shortcut:
                x = shortcut(x)
        return F.relu(paddle.add(x, y))
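
# Shape sketch for one stride-2 block from the default config below, shown
# for an illustrative 160x160 input feature map:
#
#   block = BlazeBlock(24, 24, 48, stride=2, name='blaze_2')
#   y = block(paddle.rand([1, 24, 160, 160]))  # -> shape [1, 48, 80, 80]
#
# The main path downsamples in the depthwise conv; the shortcut matches it
# with a stride-2 max-pool plus a 1x1 ConvBNLayer to 48 channels.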


@register
@serializable
class BlazeNet(nn.Layer):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): Filters for each BlazeBlock; each entry is
            [out_channels1, out_channels2] or
            [out_channels1, out_channels2, stride].
        double_blaze_filters (list): Filters for each double BlazeBlock; each
            entry is [out_channels1, out_channels2, double_channels] or
            [out_channels1, out_channels2, double_channels, stride].
        use_5x5kernel (bool): Whether the depthwise conv uses a single 5x5
            kernel (True) or two stacked 3x3 kernels (False).
    """

    def __init__(self,
                 blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48],
                                [48, 48]],
                 double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96],
                                       [96, 24, 96], [96, 24, 96, 2],
                                       [96, 24, 96], [96, 24, 96]],
                 use_5x5kernel=True,
                 act=None):
        super(BlazeNet, self).__init__()
        conv1_num_filters = blaze_filters[0][0]
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=conv1_num_filters,
            kernel_size=3,
            stride=2,
            padding=1,
            name="conv1")
        in_channels = conv1_num_filters
        self.blaze_block = []
        self._out_channels = []
        for k, v in enumerate(blaze_filters):
            assert len(v) in [2, 3], \
                "blaze_filters {} not in [2, 3]".format(v)
            if len(v) == 2:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            elif len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            stride=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            in_channels = v[1]

        for k, v in enumerate(double_blaze_filters):
            assert len(v) in [3, 4], \
                "double_blaze_filters {} not in [3, 4]".format(v)
            if len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            elif len(v) == 4:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            stride=v[3],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            in_channels = v[2]
            self._out_channels.append(in_channels)
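
    # With the default filters above, conv1 is followed by 5 BlazeBlocks and
    # 6 double BlazeBlocks; _out_channels records the output width of each
    # double block (96 for every entry under the defaults).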

    def forward(self, inputs):
        outs = []
        y = self.conv1(inputs['image'])
        for block in self.blaze_block:
            y = block(y)
            outs.append(y)
        # expose two pyramid levels: with the default config, outs[-4] is the
        # last stride-8 map and outs[-1] the last stride-16 map
        return [outs[-4], outs[-1]]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[-4], self._out_channels[-1]]
        ]
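

if __name__ == '__main__':
    # Smoke-test sketch; run with `python -m` from the package root so the
    # relative imports above resolve. 640x640 is an illustrative input size
    # (the real resolution comes from the detector config, not this file).
    model = BlazeNet()
    feats = model({'image': paddle.rand([2, 3, 640, 640])})
    for feat in feats:
        # with the default filters: [2, 96, 80, 80] then [2, 96, 40, 40]
        print(feat.shape)
    print([spec.channels for spec in model.out_shape])  # [96, 96]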