# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from numbers import Integral

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, Linear
from paddle.nn.initializer import Uniform

from paddlex.ppdet.core.workspace import register, serializable

from ..shape_spec import ShapeSpec
from .mobilenet_v3 import make_divisible, ConvBNLayer

__all__ = ['GhostNet']


class ExtraBlockDW(nn.Layer):
    """Extra downsampling block appended after the backbone: a 1x1
    pointwise conv, a grouped 3x3 conv with the given stride, and a 1x1
    projection conv, all followed by BN and ReLU6."""

    def __init__(self,
                 in_c,
                 ch_1,
                 ch_2,
                 stride,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None):
        super(ExtraBlockDW, self).__init__()
        self.pointwise_conv = ConvBNLayer(
            in_c=in_c,
            out_c=ch_1,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra1")
        self.depthwise_conv = ConvBNLayer(
            in_c=ch_1,
            out_c=ch_2,
            filter_size=3,
            stride=stride,
            padding=1,
            num_groups=int(ch_1),
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_dw")
        self.normal_conv = ConvBNLayer(
            in_c=ch_2,
            out_c=ch_2,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_sep")

    def forward(self, inputs):
        x = self.pointwise_conv(inputs)
        x = self.depthwise_conv(x)
        x = self.normal_conv(x)
        return x


class SEBlock(nn.Layer):
    """Squeeze-and-Excitation block. Note that the excitation activation
    is a hard clip to [0, 1] rather than a sigmoid."""

    def __init__(self, num_channels, lr_mult, reduction_ratio=4, name=None):
        super(SEBlock, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        med_ch = num_channels // reduction_ratio
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                learning_rate=lr_mult,
                initializer=Uniform(-stdv, stdv),
                name=name + "_1_weights"),
            bias_attr=ParamAttr(
                learning_rate=lr_mult, name=name + "_1_offset"))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_channels,
            weight_attr=ParamAttr(
                learning_rate=lr_mult,
                initializer=Uniform(-stdv, stdv),
                name=name + "_2_weights"),
            bias_attr=ParamAttr(
                learning_rate=lr_mult, name=name + "_2_offset"))

    def forward(self, inputs):
        pool = self.pool2d_gap(inputs)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        # clipping to [0, 1] serves as the gating activation here
        excitation = paddle.clip(x=excitation, min=0, max=1)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = paddle.multiply(inputs, excitation)
        return out
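
# Shape walkthrough for SEBlock.forward, with a hypothetical input of
# shape [N, C, H, W]: global pooling gives [N, C, 1, 1], squeezed to
# [N, C]; the first Linear reduces to [N, C // 4] (reduction_ratio=4),
# the second restores [N, C]; after clipping, unsqueeze yields a
# [N, C, 1, 1] gate that is broadcast-multiplied against the input.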


class GhostModule(nn.Layer):
    """Ghost module from GhostNet: a primary conv produces part of the
    output channels, a cheap depthwise conv generates the remaining
    "ghost" feature maps, and the two are concatenated along channels."""

    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 relu=True,
                 lr_mult=1.,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None):
        super(GhostModule, self).__init__()
        init_channels = int(math.ceil(output_channels / ratio))
        new_channels = int(init_channels * (ratio - 1))
        self.primary_conv = ConvBNLayer(
            in_c=in_channels,
            out_c=init_channels,
            filter_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            num_groups=1,
            act="relu" if relu else None,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_primary_conv")
        self.cheap_operation = ConvBNLayer(
            in_c=init_channels,
            out_c=new_channels,
            filter_size=dw_size,
            stride=1,
            padding=int((dw_size - 1) // 2),
            num_groups=init_channels,
            act="relu" if relu else None,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_cheap_operation")

    def forward(self, inputs):
        x = self.primary_conv(inputs)
        y = self.cheap_operation(x)
        out = paddle.concat([x, y], axis=1)
        return out
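
# Channel bookkeeping for GhostModule, as a worked example with
# hypothetical numbers: for output_channels=64 and the default ratio=2,
# the primary conv emits init_channels = ceil(64 / 2) = 32 maps, the
# cheap depthwise conv adds new_channels = 32 * (2 - 1) = 32 more, and
# the concat yields the full 64-channel output.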


class GhostBottleneck(nn.Layer):
    """GhostNet bottleneck: expand with a ghost module, optionally
    downsample with a stride-2 depthwise conv and apply SE, then project
    with a second ghost module and add a shortcut. When `return_list` is
    True, forward returns both the expanded features (used as an
    SSD/SSDLite head input) and the block output."""

    def __init__(self,
                 in_channels,
                 hidden_dim,
                 output_channels,
                 kernel_size,
                 stride,
                 use_se,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 return_list=False,
                 name=None):
        super(GhostBottleneck, self).__init__()
        self._stride = stride
        self._use_se = use_se
        self._num_channels = in_channels
        self._output_channels = output_channels
        self.return_list = return_list
        self.ghost_module_1 = GhostModule(
            in_channels=in_channels,
            output_channels=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_ghost_module_1")
        if stride == 2:
            self.depthwise_conv = ConvBNLayer(
                in_c=hidden_dim,
                out_c=hidden_dim,
                filter_size=kernel_size,
                stride=stride,
                padding=int((kernel_size - 1) // 2),
                num_groups=hidden_dim,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                # the duplicated suffix comes from an old typo and will
                # be fixed later
                name=name + "_depthwise_depthwise")
        if use_se:
            self.se_block = SEBlock(hidden_dim, lr_mult, name=name + "_se")
        self.ghost_module_2 = GhostModule(
            in_channels=hidden_dim,
            output_channels=output_channels,
            kernel_size=1,
            relu=False,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_ghost_module_2")
        if stride != 1 or in_channels != output_channels:
            self.shortcut_depthwise = ConvBNLayer(
                in_c=in_channels,
                out_c=in_channels,
                filter_size=kernel_size,
                stride=stride,
                padding=int((kernel_size - 1) // 2),
                num_groups=in_channels,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                # same old naming typo as above, will be fixed later
                name=name + "_shortcut_depthwise_depthwise")
            self.shortcut_conv = ConvBNLayer(
                in_c=in_channels,
                out_c=output_channels,
                filter_size=1,
                stride=1,
                padding=0,
                num_groups=1,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                name=name + "_shortcut_conv")

    def forward(self, inputs):
        y = self.ghost_module_1(inputs)
        x = y
        if self._stride == 2:
            x = self.depthwise_conv(x)
        if self._use_se:
            x = self.se_block(x)
        x = self.ghost_module_2(x)
        if self._stride == 1 and self._num_channels == self._output_channels:
            shortcut = inputs
        else:
            shortcut = self.shortcut_depthwise(inputs)
            shortcut = self.shortcut_conv(shortcut)
        x = paddle.add(x=x, y=shortcut)
        if self.return_list:
            return [y, x]
        else:
            return x
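
# A worked example of one bottleneck, using a hypothetical scale=1.0 and
# the cfg row [5, 72, 40, 1, 2] below: ghost_module_1 expands 24 -> 72
# channels, the stride-2 depthwise conv halves the spatial size, the SE
# block reweights the 72 channels, and ghost_module_2 projects 72 -> 40;
# since stride != 1, the shortcut is the depthwise + 1x1 projection path.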


@register
@serializable
class GhostNet(nn.Layer):
    """GhostNet backbone (arXiv:1911.11907), with optional extra
    downsampling blocks for SSD-style detectors."""

    __shared__ = ['norm_type']

    def __init__(self,
                 scale=1.3,
                 feature_maps=[6, 12, 15],
                 with_extra_blocks=False,
                 extra_block_filters=[[256, 512], [128, 256], [128, 256],
                                      [64, 128]],
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.0,
                 freeze_norm=False):
        super(GhostNet, self).__init__()
        if isinstance(feature_maps, Integral):
            feature_maps = [feature_maps]
        if norm_type == 'sync_bn' and freeze_norm:
            raise ValueError(
                "The norm_type should not be sync_bn when freeze_norm is True")
        self.feature_maps = feature_maps
        self.with_extra_blocks = with_extra_blocks
        self.extra_block_filters = extra_block_filters
        inplanes = 16
        self.cfgs = [
            # kernel_size, expansion_size, out_channels, use_se, stride
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],
            [5, 672, 160, 1, 2],  # SSDLite output
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
        ]
        self.scale = scale
        conv1_out_ch = int(make_divisible(inplanes * self.scale, 4))
        self.conv1 = ConvBNLayer(
            in_c=3,
            out_c=conv1_out_ch,
            filter_size=3,
            stride=2,
            padding=1,
            num_groups=1,
            act="relu",
            lr_mult=1.,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="conv1")
        # build the inverted residual blocks
        self._out_channels = []
        self.ghost_bottleneck_list = []
        idx = 0
        inplanes = conv1_out_ch
        for k, exp_size, c, use_se, s in self.cfgs:
            lr_idx = min(idx // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]
            # for SSD/SSDLite, the first head input is taken after the
            # expansion (ghost_module_1) of the tapped block
            return_list = self.with_extra_blocks and idx + 2 in self.feature_maps
            ghost_bottleneck = self.add_sublayer(
                "_ghostbottleneck_" + str(idx),
                sublayer=GhostBottleneck(
                    in_channels=inplanes,
                    hidden_dim=int(make_divisible(exp_size * self.scale, 4)),
                    output_channels=int(make_divisible(c * self.scale, 4)),
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    return_list=return_list,
                    name="_ghostbottleneck_" + str(idx)))
            self.ghost_bottleneck_list.append(ghost_bottleneck)
            inplanes = int(make_divisible(c * self.scale, 4))
            idx += 1
            self._update_out_channels(
                int(make_divisible(exp_size * self.scale, 4))
                if return_list else inplanes, idx + 1, feature_maps)
        if self.with_extra_blocks:
            self.extra_block_list = []
            extra_out_c = int(make_divisible(self.scale * self.cfgs[-1][1], 4))
            lr_idx = min(idx // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]
            conv_extra = self.add_sublayer(
                "conv" + str(idx + 2),
                sublayer=ConvBNLayer(
                    in_c=inplanes,
                    out_c=extra_out_c,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    num_groups=1,
                    act="relu6",
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name="conv" + str(idx + 2)))
            self.extra_block_list.append(conv_extra)
            idx += 1
            self._update_out_channels(extra_out_c, idx + 1, feature_maps)
            for j, block_filter in enumerate(self.extra_block_filters):
                in_c = extra_out_c if j == 0 else self.extra_block_filters[
                    j - 1][1]
                conv_extra = self.add_sublayer(
                    "conv" + str(idx + 2),
                    sublayer=ExtraBlockDW(
                        in_c,
                        block_filter[0],
                        block_filter[1],
                        stride=2,
                        lr_mult=lr_mult,
                        conv_decay=conv_decay,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name='conv' + str(idx + 2)))
                self.extra_block_list.append(conv_extra)
                idx += 1
                self._update_out_channels(block_filter[1], idx + 1,
                                          feature_maps)

    def _update_out_channels(self, channel, feature_idx, feature_maps):
        if feature_idx in feature_maps:
            self._out_channels.append(channel)

    def forward(self, inputs):
        x = self.conv1(inputs['image'])
        outs = []
        for idx, ghost_bottleneck in enumerate(self.ghost_bottleneck_list):
            x = ghost_bottleneck(x)
            if idx + 2 in self.feature_maps:
                if isinstance(x, list):
                    outs.append(x[0])
                    x = x[1]
                else:
                    outs.append(x)
        if not self.with_extra_blocks:
            return outs
        for i, block in enumerate(self.extra_block_list):
            idx = i + len(self.ghost_bottleneck_list)
            x = block(x)
            if idx + 2 in self.feature_maps:
                outs.append(x)
        return outs

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
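

# A minimal smoke test, not part of the original module: it assumes
# paddle and paddlex are importable, and simply checks that the backbone
# builds and produces one feature map per requested level.
if __name__ == "__main__":
    model = GhostNet(scale=1.3, feature_maps=[6, 12, 15])
    feats = model({'image': paddle.rand([1, 3, 320, 320])})
    # feature_maps=[6, 12, 15] should yield three pyramid levels
    assert len(feats) == len(model.out_shape)
    for feat, spec in zip(feats, model.out_shape):
        print(feat.shape, "channels:", spec.channels)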