mobilenet_v3.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle.fluid as fluid
  15. from paddle.fluid.param_attr import ParamAttr
  16. from paddle.fluid.regularizer import L2Decay
  17. import math
  18. class MobileNetV3():
  19. """
  20. MobileNet v3, see https://arxiv.org/abs/1905.02244
  21. Args:
  22. scale (float): scaling factor for convolution groups proportion of mobilenet_v3.
  23. model_name (str): There are two modes, small and large.
  24. norm_type (str): normalization type, 'bn' and 'sync_bn' are supported.
  25. norm_decay (float): weight decay for normalization layer weights.
  26. conv_decay (float): weight decay for convolution layer weights.
  27. with_extra_blocks (bool): if extra blocks should be added.
  28. extra_block_filters (list): number of filter for each extra block.
  29. """
  30. def __init__(self,
  31. scale=1.0,
  32. model_name='small',
  33. with_extra_blocks=False,
  34. conv_decay=0.0,
  35. norm_type='bn',
  36. norm_decay=0.0,
  37. extra_block_filters=[[256, 512], [128, 256], [128, 256],
  38. [64, 128]],
  39. num_classes=None,
  40. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
  41. for_seg=False,
  42. output_stride=None):
  43. assert len(lr_mult_list) == 5, \
  44. "lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
  45. len(lr_mult_list))
  46. self.scale = scale
  47. self.with_extra_blocks = with_extra_blocks
  48. self.extra_block_filters = extra_block_filters
  49. self.conv_decay = conv_decay
  50. self.norm_decay = norm_decay
  51. self.inplanes = 16
  52. self.end_points = []
  53. self.block_stride = 1
  54. self.num_classes = num_classes
  55. self.lr_mult_list = lr_mult_list
  56. self.curr_stage = 0
  57. self.for_seg = for_seg
  58. self.decode_point = None
  59. if self.for_seg:
  60. if model_name == "large":
  61. self.cfg = [
  62. # k, exp, c, se, nl, s,
  63. [3, 16, 16, False, 'relu', 1],
  64. [3, 64, 24, False, 'relu', 2],
  65. [3, 72, 24, False, 'relu', 1],
  66. [5, 72, 40, True, 'relu', 2],
  67. [5, 120, 40, True, 'relu', 1],
  68. [5, 120, 40, True, 'relu', 1],
  69. [3, 240, 80, False, 'hard_swish', 2],
  70. [3, 200, 80, False, 'hard_swish', 1],
  71. [3, 184, 80, False, 'hard_swish', 1],
  72. [3, 184, 80, False, 'hard_swish', 1],
  73. [3, 480, 112, True, 'hard_swish', 1],
  74. [3, 672, 112, True, 'hard_swish', 1],
  75. # The number of channels in the last 4 stages is reduced by a
  76. # factor of 2 compared to the standard implementation.
  77. [5, 336, 80, True, 'hard_swish', 2],
  78. [5, 480, 80, True, 'hard_swish', 1],
  79. [5, 480, 80, True, 'hard_swish', 1],
  80. ]
  81. self.cls_ch_squeeze = 480
  82. self.cls_ch_expand = 1280
  83. self.lr_interval = 3
  84. elif model_name == "small":
  85. self.cfg = [
  86. # k, exp, c, se, nl, s,
  87. [3, 16, 16, True, 'relu', 2],
  88. [3, 72, 24, False, 'relu', 2],
  89. [3, 88, 24, False, 'relu', 1],
  90. [5, 96, 40, True, 'hard_swish', 2],
  91. [5, 240, 40, True, 'hard_swish', 1],
  92. [5, 240, 40, True, 'hard_swish', 1],
  93. [5, 120, 48, True, 'hard_swish', 1],
  94. [5, 144, 48, True, 'hard_swish', 1],
  95. # The number of channels in the last 4 stages is reduced by a
  96. # factor of 2 compared to the standard implementation.
  97. [5, 144, 48, True, 'hard_swish', 2],
  98. [5, 288, 48, True, 'hard_swish', 1],
  99. [5, 288, 48, True, 'hard_swish', 1],
  100. ]
  101. else:
  102. raise NotImplementedError
  103. else:
  104. if model_name == "large":
  105. self.cfg = [
  106. # kernel_size, expand, channel, se_block, act_mode, stride
  107. [3, 16, 16, False, 'relu', 1],
  108. [3, 64, 24, False, 'relu', 2],
  109. [3, 72, 24, False, 'relu', 1],
  110. [5, 72, 40, True, 'relu', 2],
  111. [5, 120, 40, True, 'relu', 1],
  112. [5, 120, 40, True, 'relu', 1],
  113. [3, 240, 80, False, 'hard_swish', 2],
  114. [3, 200, 80, False, 'hard_swish', 1],
  115. [3, 184, 80, False, 'hard_swish', 1],
  116. [3, 184, 80, False, 'hard_swish', 1],
  117. [3, 480, 112, True, 'hard_swish', 1],
  118. [3, 672, 112, True, 'hard_swish', 1],
  119. [5, 672, 160, True, 'hard_swish', 2],
  120. [5, 960, 160, True, 'hard_swish', 1],
  121. [5, 960, 160, True, 'hard_swish', 1],
  122. ]
  123. self.cls_ch_squeeze = 960
  124. self.cls_ch_expand = 1280
  125. self.lr_interval = 3
  126. elif model_name == "small":
  127. self.cfg = [
  128. # kernel_size, expand, channel, se_block, act_mode, stride
  129. [3, 16, 16, True, 'relu', 2],
  130. [3, 72, 24, False, 'relu', 2],
  131. [3, 88, 24, False, 'relu', 1],
  132. [5, 96, 40, True, 'hard_swish', 2],
  133. [5, 240, 40, True, 'hard_swish', 1],
  134. [5, 240, 40, True, 'hard_swish', 1],
  135. [5, 120, 48, True, 'hard_swish', 1],
  136. [5, 144, 48, True, 'hard_swish', 1],
  137. [5, 288, 96, True, 'hard_swish', 2],
  138. [5, 576, 96, True, 'hard_swish', 1],
  139. [5, 576, 96, True, 'hard_swish', 1],
  140. ]
  141. self.cls_ch_squeeze = 576
  142. self.cls_ch_expand = 1280
  143. self.lr_interval = 2
  144. else:
  145. raise NotImplementedError
  146. if self.for_seg:
  147. self.modify_bottle_params(output_stride)
  148. def modify_bottle_params(self, output_stride=None):
  149. if output_stride is not None and output_stride % 2 != 0:
  150. raise Exception("output stride must to be even number")
  151. if output_stride is None:
  152. return
  153. else:
  154. stride = 2
  155. for i, _cfg in enumerate(self.cfg):
  156. stride = stride * _cfg[-1]
  157. if stride > output_stride:
  158. s = 1
  159. self.cfg[i][-1] = s
  160. def _conv_bn_layer(self,
  161. input,
  162. filter_size,
  163. num_filters,
  164. stride,
  165. padding,
  166. num_groups=1,
  167. if_act=True,
  168. act=None,
  169. name=None,
  170. use_cudnn=True):
  171. lr_idx = self.curr_stage // self.lr_interval
  172. lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
  173. lr_mult = self.lr_mult_list[lr_idx]
  174. if self.num_classes:
  175. regularizer = None
  176. else:
  177. regularizer = L2Decay(self.conv_decay)
  178. conv_param_attr = ParamAttr(
  179. name=name + '_weights',
  180. learning_rate=lr_mult,
  181. regularizer=regularizer)
  182. conv = fluid.layers.conv2d(
  183. input=input,
  184. num_filters=num_filters,
  185. filter_size=filter_size,
  186. stride=stride,
  187. padding=padding,
  188. groups=num_groups,
  189. act=None,
  190. use_cudnn=use_cudnn,
  191. param_attr=conv_param_attr,
  192. bias_attr=False)
  193. bn_name = name + '_bn'
  194. bn_param_attr = ParamAttr(
  195. name=bn_name + "_scale", regularizer=L2Decay(self.norm_decay))
  196. bn_bias_attr = ParamAttr(
  197. name=bn_name + "_offset", regularizer=L2Decay(self.norm_decay))
  198. bn = fluid.layers.batch_norm(
  199. input=conv,
  200. param_attr=bn_param_attr,
  201. bias_attr=bn_bias_attr,
  202. moving_mean_name=bn_name + '_mean',
  203. moving_variance_name=bn_name + '_variance')
  204. if if_act:
  205. if act == 'relu':
  206. bn = fluid.layers.relu(bn)
  207. elif act == 'hard_swish':
  208. bn = self._hard_swish(bn)
  209. elif act == 'relu6':
  210. bn = fluid.layers.relu6(bn)
  211. return bn
  212. def make_divisible(self, v, divisor=8, min_value=None):
  213. if min_value is None:
  214. min_value = divisor
  215. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  216. if new_v < 0.9 * v:
  217. new_v += divisor
  218. return new_v
  219. def _hard_swish(self, x):
  220. return x * fluid.layers.relu6(x + 3) / 6.
  221. def _se_block(self, input, num_out_filter, ratio=4, name=None):
  222. lr_idx = self.curr_stage // self.lr_interval
  223. lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
  224. lr_mult = self.lr_mult_list[lr_idx]
  225. num_mid_filter = int(num_out_filter // ratio)
  226. pool = fluid.layers.pool2d(
  227. input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
  228. conv1 = fluid.layers.conv2d(
  229. input=pool,
  230. filter_size=1,
  231. num_filters=num_mid_filter,
  232. act='relu',
  233. param_attr=ParamAttr(
  234. name=name + '_1_weights', learning_rate=lr_mult),
  235. bias_attr=ParamAttr(
  236. name=name + '_1_offset', learning_rate=lr_mult))
  237. conv2 = fluid.layers.conv2d(
  238. input=conv1,
  239. filter_size=1,
  240. num_filters=num_out_filter,
  241. act='hard_sigmoid',
  242. param_attr=ParamAttr(
  243. name=name + '_2_weights', learning_rate=lr_mult),
  244. bias_attr=ParamAttr(
  245. name=name + '_2_offset', learning_rate=lr_mult))
  246. scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
  247. return scale
  248. def _residual_unit(self,
  249. input,
  250. num_in_filter,
  251. num_mid_filter,
  252. num_out_filter,
  253. stride,
  254. filter_size,
  255. act=None,
  256. use_se=False,
  257. name=None):
  258. input_data = input
  259. conv0 = self._conv_bn_layer(
  260. input=input,
  261. filter_size=1,
  262. num_filters=num_mid_filter,
  263. stride=1,
  264. padding=0,
  265. if_act=True,
  266. act=act,
  267. name=name + '_expand')
  268. if self.block_stride == 16 and stride == 2:
  269. self.end_points.append(conv0)
  270. conv1 = self._conv_bn_layer(
  271. input=conv0,
  272. filter_size=filter_size,
  273. num_filters=num_mid_filter,
  274. stride=stride,
  275. padding=int((filter_size - 1) // 2),
  276. if_act=True,
  277. act=act,
  278. num_groups=num_mid_filter,
  279. use_cudnn=False,
  280. name=name + '_depthwise')
  281. if self.curr_stage == 5:
  282. self.decode_point = conv1
  283. if use_se:
  284. conv1 = self._se_block(
  285. input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
  286. conv2 = self._conv_bn_layer(
  287. input=conv1,
  288. filter_size=1,
  289. num_filters=num_out_filter,
  290. stride=1,
  291. padding=0,
  292. if_act=False,
  293. name=name + '_linear')
  294. if num_in_filter != num_out_filter or stride != 1:
  295. return conv2
  296. else:
  297. return fluid.layers.elementwise_add(
  298. x=input_data, y=conv2, act=None)
  299. def _extra_block_dw(self,
  300. input,
  301. num_filters1,
  302. num_filters2,
  303. stride,
  304. name=None):
  305. pointwise_conv = self._conv_bn_layer(
  306. input=input,
  307. filter_size=1,
  308. num_filters=int(num_filters1),
  309. stride=1,
  310. padding="SAME",
  311. act='relu6',
  312. name=name + "_extra1")
  313. depthwise_conv = self._conv_bn_layer(
  314. input=pointwise_conv,
  315. filter_size=3,
  316. num_filters=int(num_filters2),
  317. stride=stride,
  318. padding="SAME",
  319. num_groups=int(num_filters1),
  320. act='relu6',
  321. use_cudnn=False,
  322. name=name + "_extra2_dw")
  323. normal_conv = self._conv_bn_layer(
  324. input=depthwise_conv,
  325. filter_size=1,
  326. num_filters=int(num_filters2),
  327. stride=1,
  328. padding="SAME",
  329. act='relu6',
  330. name=name + "_extra2_sep")
  331. return normal_conv
  332. def __call__(self, input):
  333. scale = self.scale
  334. inplanes = self.inplanes
  335. cfg = self.cfg
  336. blocks = []
  337. #conv1
  338. conv = self._conv_bn_layer(
  339. input,
  340. filter_size=3,
  341. num_filters=self.make_divisible(inplanes * scale),
  342. stride=2,
  343. padding=1,
  344. num_groups=1,
  345. if_act=True,
  346. act='hard_swish',
  347. name='conv1')
  348. i = 0
  349. inplanes = self.make_divisible(inplanes * scale)
  350. for layer_cfg in cfg:
  351. self.block_stride *= layer_cfg[5]
  352. if layer_cfg[5] == 2:
  353. blocks.append(conv)
  354. conv = self._residual_unit(
  355. input=conv,
  356. num_in_filter=inplanes,
  357. num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
  358. num_out_filter=self.make_divisible(scale * layer_cfg[2]),
  359. act=layer_cfg[4],
  360. stride=layer_cfg[5],
  361. filter_size=layer_cfg[0],
  362. use_se=layer_cfg[3],
  363. name='conv' + str(i + 2))
  364. inplanes = self.make_divisible(scale * layer_cfg[2])
  365. i += 1
  366. self.curr_stage = i
  367. blocks.append(conv)
  368. if self.for_seg:
  369. conv = self._conv_bn_layer(
  370. input=conv,
  371. filter_size=1,
  372. num_filters=self.make_divisible(scale * self.cls_ch_squeeze),
  373. stride=1,
  374. padding=0,
  375. num_groups=1,
  376. if_act=True,
  377. act='hard_swish',
  378. name='conv_last')
  379. return conv, self.decode_point
  380. if self.num_classes:
  381. conv = self._conv_bn_layer(
  382. input=conv,
  383. filter_size=1,
  384. num_filters=int(scale * self.cls_ch_squeeze),
  385. stride=1,
  386. padding=0,
  387. num_groups=1,
  388. if_act=True,
  389. act='hard_swish',
  390. name='conv_last')
  391. conv = fluid.layers.pool2d(
  392. input=conv,
  393. pool_type='avg',
  394. global_pooling=True,
  395. use_cudnn=False)
  396. conv = fluid.layers.conv2d(
  397. input=conv,
  398. num_filters=self.cls_ch_expand,
  399. filter_size=1,
  400. stride=1,
  401. padding=0,
  402. act=None,
  403. param_attr=ParamAttr(name='last_1x1_conv_weights'),
  404. bias_attr=False)
  405. conv = self._hard_swish(conv)
  406. drop = fluid.layers.dropout(x=conv, dropout_prob=0.2)
  407. out = fluid.layers.fc(input=drop,
  408. size=self.num_classes,
  409. param_attr=ParamAttr(name='fc_weights'),
  410. bias_attr=ParamAttr(name='fc_offset'))
  411. return out
  412. if not self.with_extra_blocks:
  413. return blocks
  414. # extra block
  415. conv_extra = self._conv_bn_layer(
  416. conv,
  417. filter_size=1,
  418. num_filters=int(scale * cfg[-1][1]),
  419. stride=1,
  420. padding="SAME",
  421. num_groups=1,
  422. if_act=True,
  423. act='hard_swish',
  424. name='conv' + str(i + 2))
  425. self.end_points.append(conv_extra)
  426. i += 1
  427. for block_filter in self.extra_block_filters:
  428. conv_extra = self._extra_block_dw(conv_extra, block_filter[0],
  429. block_filter[1], 2,
  430. 'conv' + str(i + 2))
  431. self.end_points.append(conv_extra)
  432. i += 1
  433. return self.end_points