shufflenet_slim.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.cvlibs import manager, param_init
  18. from paddlex.paddleseg.models import layers
  19. from paddlex.paddleseg.utils import utils
  20. __all__ = ['ShuffleNetV2']
  21. @manager.MODELS.add_component
  22. class ShuffleNetV2(nn.Layer):
  23. def __init__(self, num_classes, pretrained=None, align_corners=False):
  24. super().__init__()
  25. self.pretrained = pretrained
  26. self.num_classes = num_classes
  27. self.align_corners = align_corners
  28. self.conv_bn0 = _ConvBNReLU(3, 36, 3, 2, 1)
  29. self.conv_bn1 = _ConvBNReLU(36, 18, 1, 1, 0)
  30. self.block1 = nn.Sequential(
  31. SFNetV2Module(36, stride=2, out_channels=72),
  32. SFNetV2Module(72, stride=1), SFNetV2Module(72, stride=1),
  33. SFNetV2Module(72, stride=1))
  34. self.block2 = nn.Sequential(
  35. SFNetV2Module(72, stride=2), SFNetV2Module(144, stride=1),
  36. SFNetV2Module(144, stride=1), SFNetV2Module(144, stride=1),
  37. SFNetV2Module(144, stride=1), SFNetV2Module(144, stride=1),
  38. SFNetV2Module(144, stride=1), SFNetV2Module(144, stride=1))
  39. self.depthwise_separable0 = _SeparableConvBNReLU(144, 64, 3, stride=1)
  40. self.depthwise_separable1 = _SeparableConvBNReLU(82, 64, 3, stride=1)
  41. weight_attr = paddle.ParamAttr(
  42. learning_rate=1.,
  43. regularizer=paddle.regularizer.L2Decay(coeff=0.),
  44. initializer=nn.initializer.XavierUniform())
  45. self.deconv = nn.Conv2DTranspose(
  46. 64,
  47. self.num_classes,
  48. 2,
  49. stride=2,
  50. padding=0,
  51. weight_attr=weight_attr,
  52. bias_attr=True)
  53. self.init_weight()
  54. def forward(self, x):
  55. ## Encoder
  56. conv1 = self.conv_bn0(x) # encoder 1
  57. shortcut = self.conv_bn1(conv1) # shortcut 1
  58. pool = F.max_pool2d(
  59. conv1, kernel_size=3, stride=2, padding=1) # encoder 2
  60. # Block 1
  61. conv = self.block1(pool) # encoder 3
  62. # Block 2
  63. conv = self.block2(conv) # encoder 4
  64. ### decoder
  65. conv = self.depthwise_separable0(conv)
  66. shortcut_shape = paddle.shape(shortcut)[2:]
  67. conv_b = F.interpolate(
  68. conv,
  69. shortcut_shape,
  70. mode='bilinear',
  71. align_corners=self.align_corners)
  72. concat = paddle.concat(x=[shortcut, conv_b], axis=1)
  73. decode_conv = self.depthwise_separable1(concat)
  74. logit = self.deconv(decode_conv)
  75. return [logit]
  76. def init_weight(self):
  77. for layer in self.sublayers():
  78. if isinstance(layer, nn.Conv2D):
  79. param_init.normal_init(layer.weight, std=0.001)
  80. elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
  81. param_init.constant_init(layer.weight, value=1.0)
  82. param_init.constant_init(layer.bias, value=0.0)
  83. if self.pretrained is not None:
  84. utils.load_pretrained_model(self, self.pretrained)
  85. class _ConvBNReLU(nn.Layer):
  86. def __init__(self,
  87. in_channels,
  88. out_channels,
  89. kernel_size,
  90. stride,
  91. padding,
  92. groups=1,
  93. **kwargs):
  94. super().__init__()
  95. weight_attr = paddle.ParamAttr(
  96. learning_rate=1, initializer=nn.initializer.KaimingUniform())
  97. self._conv = nn.Conv2D(
  98. in_channels,
  99. out_channels,
  100. kernel_size,
  101. padding=padding,
  102. stride=stride,
  103. groups=groups,
  104. weight_attr=weight_attr,
  105. bias_attr=False,
  106. **kwargs)
  107. self._batch_norm = layers.SyncBatchNorm(out_channels)
  108. def forward(self, x):
  109. x = self._conv(x)
  110. x = self._batch_norm(x)
  111. x = F.relu(x)
  112. return x
  113. class _ConvBN(nn.Layer):
  114. def __init__(self,
  115. in_channels,
  116. out_channels,
  117. kernel_size,
  118. stride,
  119. padding,
  120. groups=1,
  121. **kwargs):
  122. super().__init__()
  123. weight_attr = paddle.ParamAttr(
  124. learning_rate=1, initializer=nn.initializer.KaimingUniform())
  125. self._conv = nn.Conv2D(
  126. in_channels,
  127. out_channels,
  128. kernel_size,
  129. padding=padding,
  130. stride=stride,
  131. groups=groups,
  132. weight_attr=weight_attr,
  133. bias_attr=False,
  134. **kwargs)
  135. self._batch_norm = layers.SyncBatchNorm(out_channels)
  136. def forward(self, x):
  137. x = self._conv(x)
  138. x = self._batch_norm(x)
  139. return x
  140. class _SeparableConvBNReLU(nn.Layer):
  141. def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
  142. super().__init__()
  143. self.depthwise_conv = _ConvBN(
  144. in_channels,
  145. out_channels=in_channels,
  146. kernel_size=kernel_size,
  147. padding=int(kernel_size / 2),
  148. groups=in_channels,
  149. **kwargs)
  150. self.piontwise_conv = _ConvBNReLU(
  151. in_channels,
  152. out_channels,
  153. kernel_size=1,
  154. groups=1,
  155. stride=1,
  156. padding=0)
  157. def forward(self, x):
  158. x = self.depthwise_conv(x)
  159. x = self.piontwise_conv(x)
  160. return x
  161. class SFNetV2Module(nn.Layer):
  162. def __init__(self, input_channels, stride, out_channels=None):
  163. super().__init__()
  164. if stride == 1:
  165. branch_channel = int(input_channels / 2)
  166. else:
  167. branch_channel = input_channels
  168. if out_channels is None:
  169. self.in_channels = int(branch_channel)
  170. else:
  171. self.in_channels = int(out_channels / 2)
  172. self._depthwise_separable_0 = _SeparableConvBNReLU(
  173. input_channels, self.in_channels, 3, stride=stride)
  174. self._conv = _ConvBNReLU(
  175. branch_channel, self.in_channels, 1, stride=1, padding=0)
  176. self._depthwise_separable_1 = _SeparableConvBNReLU(
  177. self.in_channels, self.in_channels, 3, stride=stride)
  178. self.stride = stride
  179. def forward(self, input):
  180. if self.stride == 1:
  181. shortcut, branch = paddle.split(x=input, num_or_sections=2, axis=1)
  182. else:
  183. branch = input
  184. shortcut = self._depthwise_separable_0(input)
  185. branch_1x1 = self._conv(branch)
  186. branch_dw1x1 = self._depthwise_separable_1(branch_1x1)
  187. output = paddle.concat(x=[shortcut, branch_dw1x1], axis=1)
  188. # channel shuffle
  189. out_shape = paddle.shape(output)
  190. h, w = out_shape[2], out_shape[3]
  191. output = paddle.reshape(x=output, shape=[0, 2, self.in_channels, h, w])
  192. output = paddle.transpose(x=output, perm=[0, 2, 1, 3, 4])
  193. output = paddle.reshape(x=output, shape=[0, 2 * self.in_channels, h, w])
  194. return output
  195. if __name__ == '__main__':
  196. import numpy as np
  197. import os
  198. np.random.seed(100)
  199. paddle.seed(100)
  200. net = ShuffleNetV2(10)
  201. img = np.random.random(size=(4, 3, 100, 100)).astype('float32')
  202. img = paddle.to_tensor(img)
  203. out = net(img)
  204. print(out)
  205. net.forward = paddle.jit.to_static(net.forward)
  206. save_path = os.path.join('.', 'model')
  207. in_var = paddle.ones([4, 3, 100, 100])
  208. paddle.jit.save(net, save_path, input_spec=[in_var])