mobilenet_v3.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay

from paddlex.ppdet.core.workspace import register, serializable
from numbers import Integral

from ..shape_spec import ShapeSpec

__all__ = ['MobileNetV3']


def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
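

# Worked examples (added for illustration; the values below were computed by
# hand from the function above and are not part of the original file).
# make_divisible rounds a scaled channel count to the nearest multiple of
# `divisor`, bumping up one step whenever rounding would fall below 90% of
# the requested width:
#   make_divisible(16 * 0.75)  # 12.0 -> 16
#   make_divisible(10)         # 8 < 0.9 * 10, so 8 + 8 -> 16
#   make_divisible(100)        # -> 104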


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_c,
                 out_c,
                 filter_size,
                 stride,
                 padding,
                 num_groups=1,
                 act=None,
                 lr_mult=1.,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=""):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=lr_mult,
                regularizer=L2Decay(conv_decay),
                name=name + "_weights"),
            bias_attr=False)

        norm_lr = 0. if freeze_norm else lr_mult
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=name + "_bn_scale",
            trainable=False if freeze_norm else True)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=name + "_bn_offset",
            trainable=False if freeze_norm else True)
        global_stats = True if freeze_norm else False
        if norm_type == 'sync_bn':
            self.bn = nn.SyncBatchNorm(
                out_c, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self.bn = nn.BatchNorm(
                out_c,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats,
                moving_mean_name=name + '_bn_mean',
                moving_variance_name=name + '_bn_variance')
        norm_params = self.bn.parameters()
        if freeze_norm:
            for param in norm_params:
                param.stop_gradient = True

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "relu6":
                x = F.relu6(x)
            elif self.act == "hard_swish":
                x = F.hardswish(x)
            else:
                raise NotImplementedError(
                    "Unsupported activation: {}".format(self.act))
        return x
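

# ResidualUnit is the MobileNetV3 inverted-residual block: a 1x1 expansion
# conv, a depthwise conv (groups == mid_c), an optional squeeze-and-excitation
# stage, and a linear 1x1 projection. The identity shortcut is applied only
# when stride == 1 and in_c == out_c. With return_list=True the block also
# returns the expansion output, which SSD/SSDLite heads tap as an extra
# feature (see MobileNetV3 below).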


class ResidualUnit(nn.Layer):
    def __init__(self,
                 in_c,
                 mid_c,
                 out_c,
                 filter_size,
                 stride,
                 use_se,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 act=None,
                 return_list=False,
                 name=''):
        super(ResidualUnit, self).__init__()
        self.if_shortcut = stride == 1 and in_c == out_c
        self.use_se = use_se
        self.return_list = return_list

        self.expand_conv = ConvBNLayer(
            in_c=in_c,
            out_c=mid_c,
            filter_size=1,
            stride=1,
            padding=0,
            act=act,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_expand")
        self.bottleneck_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=mid_c,
            filter_size=filter_size,
            stride=stride,
            padding=int((filter_size - 1) // 2),
            num_groups=mid_c,
            act=act,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_depthwise")
        if self.use_se:
            self.mid_se = SEModule(
                mid_c, lr_mult, conv_decay, name=name + "_se")
        self.linear_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=out_c,
            filter_size=1,
            stride=1,
            padding=0,
            act=None,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_linear")

    def forward(self, inputs):
        y = self.expand_conv(inputs)
        x = self.bottleneck_conv(y)
        if self.use_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(inputs, x)
        if self.return_list:
            return [y, x]
        else:
            return x
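

# SEModule is squeeze-and-excitation channel attention: global average
# pooling, a 1x1 reduction conv (channel // reduction), ReLU, a 1x1 expansion
# conv, and a hard-sigmoid gate clip(0.2 * x + 0.5, 0, 1) that rescales the
# input feature map channel-wise.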


class SEModule(nn.Layer):
    def __init__(self, channel, lr_mult, conv_decay, reduction=4, name=""):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        mid_channels = int(channel // reduction)
        self.conv1 = nn.Conv2D(
            in_channels=channel,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(
                learning_rate=lr_mult,
                regularizer=L2Decay(conv_decay),
                name=name + "_1_weights"),
            bias_attr=ParamAttr(
                learning_rate=lr_mult,
                regularizer=L2Decay(conv_decay),
                name=name + "_1_offset"))
        self.conv2 = nn.Conv2D(
            in_channels=mid_channels,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(
                learning_rate=lr_mult,
                regularizer=L2Decay(conv_decay),
                name=name + "_2_weights"),
            bias_attr=ParamAttr(
                learning_rate=lr_mult,
                regularizer=L2Decay(conv_decay),
                name=name + "_2_offset"))

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
        return paddle.multiply(x=inputs, y=outputs)
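

# ExtraBlockDW is the depthwise-separable downsampling block appended after
# the backbone for SSD/SSDLite-style heads: a 1x1 pointwise expansion, a 3x3
# depthwise conv (stride 2 when used below), and a 1x1 pointwise projection,
# all with ReLU6 activations.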


class ExtraBlockDW(nn.Layer):
    def __init__(self,
                 in_c,
                 ch_1,
                 ch_2,
                 stride,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None):
        super(ExtraBlockDW, self).__init__()
        self.pointwise_conv = ConvBNLayer(
            in_c=in_c,
            out_c=ch_1,
            filter_size=1,
            stride=1,
            padding='SAME',
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra1")
        self.depthwise_conv = ConvBNLayer(
            in_c=ch_1,
            out_c=ch_2,
            filter_size=3,
            stride=stride,
            padding='SAME',
            num_groups=int(ch_1),
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_dw")
        self.normal_conv = ConvBNLayer(
            in_c=ch_2,
            out_c=ch_2,
            filter_size=1,
            stride=1,
            padding='SAME',
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_sep")

    def forward(self, inputs):
        x = self.pointwise_conv(inputs)
        x = self.depthwise_conv(x)
        x = self.normal_conv(x)
        return x
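

# MobileNetV3 assembles the backbone from the per-block `cfg` table below
# (kernel size k, expansion width exp, output channels c, SE flag, activation
# nl, stride s). Blocks are registered as "conv2", "conv3", ..., and an entry
# f in `feature_maps` exports the output of block "conv{f}". Every three
# blocks share one entry of `lr_mult_list`, and `with_extra_blocks` appends
# the SSD/SSDLite stages described by `extra_block_filters`. The @register
# and @serializable decorators expose the class to the ppdet config
# workspace.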


@register
@serializable
class MobileNetV3(nn.Layer):
    __shared__ = ['norm_type']

    def __init__(self,
                 scale=1.0,
                 model_name="large",
                 feature_maps=[6, 12, 15],
                 with_extra_blocks=False,
                 extra_block_filters=[[256, 512], [128, 256], [128, 256],
                                      [64, 128]],
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                 conv_decay=0.0,
                 multiplier=1.0,
                 norm_type='bn',
                 norm_decay=0.0,
                 freeze_norm=False):
        super(MobileNetV3, self).__init__()
        if isinstance(feature_maps, Integral):
            feature_maps = [feature_maps]
        if norm_type == 'sync_bn' and freeze_norm:
            raise ValueError(
                "The norm_type should not be sync_bn when freeze_norm is True")
        self.feature_maps = feature_maps
        self.with_extra_blocks = with_extra_blocks
        self.extra_block_filters = extra_block_filters

        inplanes = 16
        if model_name == "large":
            self.cfg = [
                # k, exp, c, se, nl, s
                [3, 16, 16, False, "relu", 1],
                [3, 64, 24, False, "relu", 2],
                [3, 72, 24, False, "relu", 1],
                [5, 72, 40, True, "relu", 2],  # RCNN output
                [5, 120, 40, True, "relu", 1],
                [5, 120, 40, True, "relu", 1],  # YOLOv3 output
                [3, 240, 80, False, "hard_swish", 2],  # RCNN output
                [3, 200, 80, False, "hard_swish", 1],
                [3, 184, 80, False, "hard_swish", 1],
                [3, 184, 80, False, "hard_swish", 1],
                [3, 480, 112, True, "hard_swish", 1],
                [3, 672, 112, True, "hard_swish", 1],  # YOLOv3 output
                [5, 672, 160, True, "hard_swish", 2],  # SSD/SSDLite/RCNN output
                [5, 960, 160, True, "hard_swish", 1],
                [5, 960, 160, True, "hard_swish", 1],  # YOLOv3 output
            ]
        elif model_name == "small":
            self.cfg = [
                # k, exp, c, se, nl, s
                [3, 16, 16, True, "relu", 2],
                [3, 72, 24, False, "relu", 2],  # RCNN output
                [3, 88, 24, False, "relu", 1],  # YOLOv3 output
                [5, 96, 40, True, "hard_swish", 2],  # RCNN output
                [5, 240, 40, True, "hard_swish", 1],
                [5, 240, 40, True, "hard_swish", 1],
                [5, 120, 48, True, "hard_swish", 1],
                [5, 144, 48, True, "hard_swish", 1],  # YOLOv3 output
                [5, 288, 96, True, "hard_swish", 2],  # SSD/SSDLite/RCNN output
                [5, 576, 96, True, "hard_swish", 1],
                [5, 576, 96, True, "hard_swish", 1],  # YOLOv3 output
            ]
        else:
            raise NotImplementedError(
                "model_name [{}] is not implemented; expected 'large' or "
                "'small'.".format(model_name))

        if multiplier != 1.0:
            self.cfg[-3][2] = int(self.cfg[-3][2] * multiplier)
            self.cfg[-2][1] = int(self.cfg[-2][1] * multiplier)
            self.cfg[-2][2] = int(self.cfg[-2][2] * multiplier)
            self.cfg[-1][1] = int(self.cfg[-1][1] * multiplier)
            self.cfg[-1][2] = int(self.cfg[-1][2] * multiplier)
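
        # Note (added comment): `multiplier` thins only the expansion/output
        # widths of the last three cfg entries above, while `scale` rescales
        # every layer through make_divisible below.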
        self.conv1 = ConvBNLayer(
            in_c=3,
            out_c=make_divisible(inplanes * scale),
            filter_size=3,
            stride=2,
            padding=1,
            num_groups=1,
            act="hard_swish",
            lr_mult=lr_mult_list[0],
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="conv1")

        self._out_channels = []
        self.block_list = []
        i = 0
        inplanes = make_divisible(inplanes * scale)
        for (k, exp, c, se, nl, s) in self.cfg:
            lr_idx = min(i // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]
            # for SSD/SSDLite, first head input is after ResidualUnit expand_conv
            return_list = self.with_extra_blocks and i + 2 in self.feature_maps
            block = self.add_sublayer(
                "conv" + str(i + 2),
                sublayer=ResidualUnit(
                    in_c=inplanes,
                    mid_c=make_divisible(scale * exp),
                    out_c=make_divisible(scale * c),
                    filter_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    return_list=return_list,
                    name="conv" + str(i + 2)))
            self.block_list.append(block)
            inplanes = make_divisible(scale * c)
            i += 1
            self._update_out_channels(
                make_divisible(scale * exp) if return_list else inplanes,
                i + 1, feature_maps)

        if self.with_extra_blocks:
            self.extra_block_list = []
            extra_out_c = make_divisible(scale * self.cfg[-1][1])
            lr_idx = min(i // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]

            conv_extra = self.add_sublayer(
                "conv" + str(i + 2),
                sublayer=ConvBNLayer(
                    in_c=inplanes,
                    out_c=extra_out_c,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    num_groups=1,
                    act="hard_swish",
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name="conv" + str(i + 2)))
            self.extra_block_list.append(conv_extra)
            i += 1
            self._update_out_channels(extra_out_c, i + 1, feature_maps)

            for j, block_filter in enumerate(self.extra_block_filters):
                in_c = extra_out_c if j == 0 else self.extra_block_filters[
                    j - 1][1]
                conv_extra = self.add_sublayer(
                    "conv" + str(i + 2),
                    sublayer=ExtraBlockDW(
                        in_c,
                        block_filter[0],
                        block_filter[1],
                        stride=2,
                        lr_mult=lr_mult,
                        conv_decay=conv_decay,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name='conv' + str(i + 2)))
                self.extra_block_list.append(conv_extra)
                i += 1
                self._update_out_channels(block_filter[1], i + 1,
                                          feature_maps)

    def _update_out_channels(self, channel, feature_idx, feature_maps):
        if feature_idx in feature_maps:
            self._out_channels.append(channel)

    def forward(self, inputs):
        x = self.conv1(inputs['image'])
        outs = []
        for idx, block in enumerate(self.block_list):
            x = block(x)
            if idx + 2 in self.feature_maps:
                if isinstance(x, list):
                    outs.append(x[0])
                    x = x[1]
                else:
                    outs.append(x)

        if not self.with_extra_blocks:
            return outs

        for i, block in enumerate(self.extra_block_list):
            idx = i + len(self.block_list)
            x = block(x)
            if idx + 2 in self.feature_maps:
                outs.append(x)
        return outs

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
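

# A minimal usage sketch (added for illustration, not part of the original
# file; it assumes the class can be instantiated directly outside the ppdet
# config workspace and that paddle is installed). The backbone consumes a
# dict holding an 'image' tensor of shape [N, 3, H, W] and returns one
# feature map per entry in `feature_maps`.
if __name__ == '__main__':
    model = MobileNetV3(
        scale=1.0, model_name="large", feature_maps=[6, 12, 15])
    feats = model({'image': paddle.rand([1, 3, 320, 320])})
    print([f.shape for f in feats])  # one entry per selected feature map
    print([s.channels for s in model.out_shape])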