# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MixNet for ImageNet-1K, implemented in Paddle.
Original paper: 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
"""

__all__ = ['MixNet_S', 'MixNet_M', 'MixNet_L']

from inspect import isfunction
from functools import reduce

import paddle
import paddle.nn as nn


class Identity(nn.Layer):
    """
    Identity block.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


def round_channels(channels, divisor=8):
    """
    Round weighted channel number (make divisible operation).

    Parameters:
    ----------
    channels : int or float
        Original number of channels.
    divisor : int, default 8
        Alignment value.

    Returns:
    -------
    int
        Weighted number of channels.
    """
    rounded_channels = max(
        int(channels + divisor / 2.0) // divisor * divisor, divisor)
    if float(rounded_channels) < 0.9 * channels:
        rounded_channels += divisor
    return rounded_channels
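
# A quick sanity check of the rounding arithmetic (illustrative values, not
# from the original source): round_channels(17) -> 16, since
# int(17 + 4) // 8 * 8 = 16 and 16 >= 0.9 * 17; round_channels(7) -> 8,
# because the result is clamped to at least `divisor`.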


def get_activation_layer(activation):
    """
    Create activation layer from string/function.

    Parameters:
    ----------
    activation : function, or str, or nn.Layer
        Activation function or name of activation function.

    Returns:
    -------
    nn.Layer
        Activation layer.
    """
    assert activation is not None
    if isfunction(activation):
        return activation()
    elif isinstance(activation, str):
        if activation == "relu":
            return nn.ReLU()
        elif activation == "relu6":
            return nn.ReLU6()
        elif activation == "swish":
            return nn.Swish()
        elif activation == "hswish":
            return nn.Hardswish()
        elif activation == "sigmoid":
            return nn.Sigmoid()
        elif activation == "hsigmoid":
            return nn.Hardsigmoid()
        elif activation == "identity":
            return Identity()
        else:
            raise NotImplementedError()
    else:
        assert isinstance(activation, nn.Layer)
        return activation
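
# Usage sketch (assumed, not part of the original file): every accepted form
# yields a ready-to-call layer, e.g.
#   act = get_activation_layer("swish")    # -> nn.Swish()
#   act = get_activation_layer(nn.ReLU6())  # -> the passed layer, unchanged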


class ConvBlock(nn.Layer):
    """
    Standard convolution block with Batch normalization and activation.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int, or tuple/list of 2 int, or tuple/list of 4 int
        Padding value for convolution layer.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    bn_eps : float, default 1e-5
        Small float added to variance in Batch norm.
    activation : function or str or None, default nn.ReLU()
        Activation function or name of activation function.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 groups=1,
                 bias=False,
                 use_bn=True,
                 bn_eps=1e-5,
                 activation=nn.ReLU()):
        super(ConvBlock, self).__init__()
        self.activate = (activation is not None)
        self.use_bn = use_bn
        self.use_pad = (isinstance(padding, (list, tuple)) and
                        (len(padding) == 4))
        if self.use_pad:
            self.pad = padding
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias,
            weight_attr=None)
        if self.use_bn:
            self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
        if self.activate:
            self.activ = get_activation_layer(activation)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.activate:
            x = self.activ(x)
        return x
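
# Shape sketch for ConvBlock (hypothetical values): with in_channels=3,
# out_channels=32, kernel_size=3, stride=2, padding=1, an input of
# [N, 3, 224, 224] maps to [N, 32, 112, 112] (conv -> BN -> ReLU).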


class SEBlock(nn.Layer):
    """
    Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,'
    https://arxiv.org/abs/1709.01507.
    """

    def __init__(self,
                 channels,
                 reduction=16,
                 mid_channels=None,
                 round_mid=False,
                 use_conv=True,
                 mid_activation=nn.ReLU(),
                 out_activation=nn.Sigmoid()):
        super(SEBlock, self).__init__()
        self.use_conv = use_conv
        if mid_channels is None:
            mid_channels = channels // reduction if not round_mid else round_channels(
                float(channels) / reduction)
        self.pool = nn.AdaptiveAvgPool2D(output_size=1)
        if use_conv:
            self.conv1 = nn.Conv2D(
                in_channels=channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                bias_attr=True,
                weight_attr=None)
        else:
            self.fc1 = nn.Linear(
                in_features=channels, out_features=mid_channels)
        self.activ = get_activation_layer(mid_activation)
        if use_conv:
            self.conv2 = nn.Conv2D(
                in_channels=mid_channels,
                out_channels=channels,
                kernel_size=1,
                stride=1,
                groups=1,
                bias_attr=True,
                weight_attr=None)
        else:
            self.fc2 = nn.Linear(
                in_features=mid_channels, out_features=channels)
        self.sigmoid = get_activation_layer(out_activation)

    def forward(self, x):
        w = self.pool(x)
        if not self.use_conv:
            w = w.reshape(shape=[w.shape[0], -1])
        w = self.conv1(w) if self.use_conv else self.fc1(w)
        w = self.activ(w)
        w = self.conv2(w) if self.use_conv else self.fc2(w)
        w = self.sigmoid(w)
        if not self.use_conv:
            w = w.unsqueeze(2).unsqueeze(3)
        x = x * w
        return x
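
# How SEBlock recalibrates channels: global average pooling squeezes
# [N, C, H, W] down to [N, C, 1, 1], two 1x1 convolutions (or linear layers)
# produce per-channel gates in (0, 1), and the input is rescaled by them.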


class MixConv(nn.Layer):
    """
    Mixed convolution layer from 'MixConv: Mixed Depthwise Convolutional Kernels,'
    https://arxiv.org/abs/1907.09595.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
        Padding value for convolution layer.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    axis : int, default 1
        The axis on which to concatenate the outputs.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 groups=1,
                 bias=False,
                 axis=1):
        super(MixConv, self).__init__()
        kernel_size = kernel_size if isinstance(kernel_size,
                                                list) else [kernel_size]
        padding = padding if isinstance(padding, list) else [padding]
        kernel_count = len(kernel_size)
        self.splitted_in_channels = self.split_channels(in_channels,
                                                        kernel_count)
        splitted_out_channels = self.split_channels(out_channels, kernel_count)
        for i, kernel_size_i in enumerate(kernel_size):
            in_channels_i = self.splitted_in_channels[i]
            out_channels_i = splitted_out_channels[i]
            padding_i = padding[i]
            _ = self.add_sublayer(
                name=str(i),
                sublayer=nn.Conv2D(
                    in_channels=in_channels_i,
                    out_channels=out_channels_i,
                    kernel_size=kernel_size_i,
                    stride=stride,
                    padding=padding_i,
                    dilation=dilation,
                    groups=(out_channels_i
                            if out_channels == groups else groups),
                    bias_attr=bias,
                    weight_attr=None))
        self.axis = axis

    def forward(self, x):
        xx = paddle.split(x, self.splitted_in_channels, axis=self.axis)
        out = [
            conv_i(x_i) for x_i, conv_i in zip(xx, self._sub_layers.values())
        ]
        x = paddle.concat(tuple(out), axis=self.axis)
        return x

    @staticmethod
    def split_channels(channels, kernel_count):
        splitted_channels = [channels // kernel_count] * kernel_count
        splitted_channels[0] += channels - sum(splitted_channels)
        return splitted_channels
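
# split_channels distributes channels as evenly as possible, with any
# remainder going to the first group, e.g. (illustrative values):
#   MixConv.split_channels(24, 3) -> [8, 8, 8]
#   MixConv.split_channels(26, 3) -> [10, 8, 8]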


class MixConvBlock(nn.Layer):
    """
    Mixed convolution block with Batch normalization and activation.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
        Convolution window size.
    stride : int or tuple/list of 2 int
        Strides of the convolution.
    padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
        Padding value for convolution layer.
    dilation : int or tuple/list of 2 int, default 1
        Dilation value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    bn_eps : float, default 1e-5
        Small float added to variance in Batch norm.
    activation : function or str or None, default nn.ReLU()
        Activation function or name of activation function.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 groups=1,
                 bias=False,
                 use_bn=True,
                 bn_eps=1e-5,
                 activation=nn.ReLU()):
        super(MixConvBlock, self).__init__()
        self.activate = (activation is not None)
        self.use_bn = use_bn
        self.conv = MixConv(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        if self.use_bn:
            self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
        if self.activate:
            self.activ = get_activation_layer(activation)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.activate:
            x = self.activ(x)
        return x


def mixconv1x1_block(in_channels,
                     out_channels,
                     kernel_count,
                     stride=1,
                     groups=1,
                     bias=False,
                     use_bn=True,
                     bn_eps=1e-5,
                     activation=nn.ReLU()):
    """
    1x1 version of the mixed convolution block.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_count : int
        Kernel count.
    stride : int or tuple/list of 2 int, default 1
        Strides of the convolution.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    use_bn : bool, default True
        Whether to use BatchNorm layer.
    bn_eps : float, default 1e-5
        Small float added to variance in Batch norm.
    activation : function or str, or None, default nn.ReLU()
        Activation function or name of activation function.
    """
    return MixConvBlock(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=([1] * kernel_count),
        stride=stride,
        padding=([0] * kernel_count),
        groups=groups,
        bias=bias,
        use_bn=use_bn,
        bn_eps=bn_eps,
        activation=activation)
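
# Example (assumed values): mixconv1x1_block(24, 96, kernel_count=2) splits
# the 24 input channels into two groups of 12, applies a separate 1x1
# convolution to each, and concatenates the results back into 96 channels.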


class MixUnit(nn.Layer):
    """
    MixNet unit.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int
        Strides of the second convolution layer.
    exp_kernel_count : int
        Expansion convolution kernel count for each unit.
    conv1_kernel_count : int
        Conv1 kernel count for each unit.
    conv2_kernel_count : int
        Conv2 kernel count for each unit.
    exp_factor : int
        Expansion factor for each unit.
    se_factor : int
        SE reduction factor for each unit.
    activation : str
        Activation function or name of activation function.
    """

    def __init__(self, in_channels, out_channels, stride, exp_kernel_count,
                 conv1_kernel_count, conv2_kernel_count, exp_factor, se_factor,
                 activation):
        super(MixUnit, self).__init__()
        assert exp_factor >= 1
        assert se_factor >= 0
        self.residual = (in_channels == out_channels) and (stride == 1)
        self.use_se = se_factor > 0
        mid_channels = exp_factor * in_channels
        self.use_exp_conv = exp_factor > 1
        if self.use_exp_conv:
            if exp_kernel_count == 1:
                self.exp_conv = ConvBlock(
                    in_channels=in_channels,
                    out_channels=mid_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    use_bn=True,
                    bn_eps=1e-5,
                    activation=activation)
            else:
                self.exp_conv = mixconv1x1_block(
                    in_channels=in_channels,
                    out_channels=mid_channels,
                    kernel_count=exp_kernel_count,
                    activation=activation)
        if conv1_kernel_count == 1:
            self.conv1 = ConvBlock(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                dilation=1,
                groups=mid_channels,
                bias=False,
                use_bn=True,
                bn_eps=1e-5,
                activation=activation)
        else:
            self.conv1 = MixConvBlock(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=[3 + 2 * i for i in range(conv1_kernel_count)],
                stride=stride,
                padding=[1 + i for i in range(conv1_kernel_count)],
                groups=mid_channels,
                activation=activation)
        if self.use_se:
            self.se = SEBlock(
                channels=mid_channels,
                reduction=(exp_factor * se_factor),
                round_mid=False,
                mid_activation=activation)
        if conv2_kernel_count == 1:
            self.conv2 = ConvBlock(
                in_channels=mid_channels,
                out_channels=out_channels,
                activation=None,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=False,
                use_bn=True,
                bn_eps=1e-5)
        else:
            self.conv2 = mixconv1x1_block(
                in_channels=mid_channels,
                out_channels=out_channels,
                kernel_count=conv2_kernel_count,
                activation=None)

    def forward(self, x):
        if self.residual:
            identity = x
        if self.use_exp_conv:
            x = self.exp_conv(x)
        x = self.conv1(x)
        if self.use_se:
            x = self.se(x)
        x = self.conv2(x)
        if self.residual:
            x = x + identity
        return x
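
# For conv1_kernel_count > 1, the depthwise MixConvBlock above uses odd kernel
# sizes 3, 5, 7, ... with matching paddings 1, 2, 3, ..., so every branch
# produces the same spatial size (up to the shared stride) and the outputs can
# be concatenated. E.g. conv1_kernel_count=3: kernel_size=[3, 5, 7],
# padding=[1, 2, 3].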


class MixInitBlock(nn.Layer):
    """
    MixNet specific initial block.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super(MixInitBlock, self).__init__()
        self.conv1 = ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=2,
            kernel_size=3,
            padding=1)
        self.conv2 = MixUnit(
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1,
            exp_kernel_count=1,
            conv1_kernel_count=1,
            conv2_kernel_count=1,
            exp_factor=1,
            se_factor=0,
            activation="relu")

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class MixNet(nn.Layer):
    """
    MixNet model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
    https://arxiv.org/abs/1907.09595.

    Parameters:
    ----------
    channels : list of list of int
        Number of output channels for each unit.
    init_block_channels : int
        Number of output channels for the initial unit.
    final_block_channels : int
        Number of output channels for the final block of the feature extractor.
    exp_kernel_counts : list of list of int
        Expansion convolution kernel count for each unit.
    conv1_kernel_counts : list of list of int
        Conv1 kernel count for each unit.
    conv2_kernel_counts : list of list of int
        Conv2 kernel count for each unit.
    exp_factors : list of list of int
        Expansion factor for each unit.
    se_factors : list of list of int
        SE reduction factor for each unit.
    in_channels : int, default 3
        Number of input channels.
    in_size : tuple of two ints, default (224, 224)
        Spatial size of the expected input image.
    class_dim : int, default 1000
        Number of classification classes.
    """

    def __init__(self,
                 channels,
                 init_block_channels,
                 final_block_channels,
                 exp_kernel_counts,
                 conv1_kernel_counts,
                 conv2_kernel_counts,
                 exp_factors,
                 se_factors,
                 in_channels=3,
                 in_size=(224, 224),
                 class_dim=1000):
        super(MixNet, self).__init__()
        self.in_size = in_size
        self.class_dim = class_dim
        self.features = nn.Sequential()
        self.features.add_sublayer(
            "init_block",
            MixInitBlock(
                in_channels=in_channels, out_channels=init_block_channels))
        in_channels = init_block_channels
        for i, channels_per_stage in enumerate(channels):
            stage = nn.Sequential()
            for j, out_channels in enumerate(channels_per_stage):
                stride = 2 if ((j == 0) and (i != 3)) or (
                    (j == len(channels_per_stage) // 2) and (i == 3)) else 1
                exp_kernel_count = exp_kernel_counts[i][j]
                conv1_kernel_count = conv1_kernel_counts[i][j]
                conv2_kernel_count = conv2_kernel_counts[i][j]
                exp_factor = exp_factors[i][j]
                se_factor = se_factors[i][j]
                activation = "relu" if i == 0 else "swish"
                stage.add_sublayer(
                    "unit{}".format(j + 1),
                    MixUnit(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        stride=stride,
                        exp_kernel_count=exp_kernel_count,
                        conv1_kernel_count=conv1_kernel_count,
                        conv2_kernel_count=conv2_kernel_count,
                        exp_factor=exp_factor,
                        se_factor=se_factor,
                        activation=activation))
                in_channels = out_channels
            self.features.add_sublayer("stage{}".format(i + 1), stage)
        self.features.add_sublayer(
            "final_block",
            ConvBlock(
                in_channels=in_channels,
                out_channels=final_block_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=False,
                use_bn=True,
                bn_eps=1e-5,
                activation=nn.ReLU()))
        in_channels = final_block_channels
        self.features.add_sublayer(
            "final_pool", nn.AvgPool2D(
                kernel_size=7, stride=1))
        self.output = nn.Linear(
            in_features=in_channels, out_features=class_dim)

    def forward(self, x):
        x = self.features(x)
        reshape_dim = reduce(lambda x, y: x * y, x.shape[1:])
        x = x.reshape(shape=[x.shape[0], reshape_dim])
        x = self.output(x)
        return x


def get_mixnet(version, width_scale, model_name=None, **kwargs):
    """
    Create MixNet model with specific parameters.

    Parameters:
    ----------
    version : str
        Version of MixNet ('s' or 'm').
    width_scale : float
        Scale factor for width of layers.
    model_name : str or None, default None
        Model name.
    """
    if version == "s":
        init_block_channels = 16
        channels = [[24, 24], [40, 40, 40, 40], [80, 80, 80],
                    [120, 120, 120, 200, 200, 200]]
        exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 1, 1],
                             [2, 2, 2, 1, 1, 1]]
        conv1_kernel_counts = [[1, 1], [3, 2, 2, 2], [3, 2, 2],
                               [3, 4, 4, 5, 4, 4]]
        conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [2, 2, 2],
                               [2, 2, 2, 1, 2, 2]]
        exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6], [6, 3, 3, 6, 6, 6]]
        se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4], [2, 2, 2, 2, 2, 2]]
    elif version == "m":
        init_block_channels = 24
        channels = [[32, 32], [40, 40, 40, 40], [80, 80, 80, 80],
                    [120, 120, 120, 120, 200, 200, 200, 200]]
        exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
                             [1, 2, 2, 2, 1, 1, 1, 1]]
        conv1_kernel_counts = [[3, 1], [4, 2, 2, 2], [3, 4, 4, 4],
                               [1, 4, 4, 4, 4, 4, 4, 4]]
        conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
                               [1, 2, 2, 2, 1, 2, 2, 2]]
        exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6, 6],
                       [6, 3, 3, 3, 6, 6, 6, 6]]
        se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4, 4],
                      [2, 2, 2, 2, 2, 2, 2, 2]]
    else:
        raise ValueError("Unsupported MixNet version {}".format(version))

    final_block_channels = 1536
    if width_scale != 1.0:
        channels = [[round_channels(cij * width_scale) for cij in ci]
                    for ci in channels]
        init_block_channels = round_channels(init_block_channels * width_scale)

    net = MixNet(
        channels=channels,
        init_block_channels=init_block_channels,
        final_block_channels=final_block_channels,
        exp_kernel_counts=exp_kernel_counts,
        conv1_kernel_counts=conv1_kernel_counts,
        conv2_kernel_counts=conv2_kernel_counts,
        exp_factors=exp_factors,
        se_factors=se_factors,
        **kwargs)
    return net


def MixNet_S(**kwargs):
    """
    MixNet-S model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
    https://arxiv.org/abs/1907.09595.
    """
    return get_mixnet(
        version="s", width_scale=1.0, model_name="MixNet_S", **kwargs)


def MixNet_M(**kwargs):
    """
    MixNet-M model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
    https://arxiv.org/abs/1907.09595.
    """
    return get_mixnet(
        version="m", width_scale=1.0, model_name="MixNet_M", **kwargs)


def MixNet_L(**kwargs):
    """
    MixNet-L model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
    https://arxiv.org/abs/1907.09595. MixNet-L is the MixNet-M architecture
    with a width scale of 1.3.
    """
    return get_mixnet(
        version="m", width_scale=1.3, model_name="MixNet_L", **kwargs)
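

# Minimal smoke test (a sketch, not part of the original file): build each
# variant and check the classifier output shape on a random 224x224 batch.
if __name__ == "__main__":
    for constructor in (MixNet_S, MixNet_M, MixNet_L):
        model = constructor()
        model.eval()
        x = paddle.randn([1, 3, 224, 224])
        y = model(x)
        assert y.shape == [1, 1000]
        print("{}: output shape {}".format(constructor.__name__, y.shape))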