squeezenet.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import paddle
  2. from paddle import ParamAttr
  3. import paddle.nn as nn
  4. import paddle.nn.functional as F
  5. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  6. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  7. __all__ = ["SqueezeNet1_0", "SqueezeNet1_1"]
  class MakeFireConv(nn.Layer):
      """Convolution followed by ReLU; the building unit of a Fire module.

      Args:
          input_channels: Number of input channels given to ``Conv2D``.
          output_channels: Number of output channels given to ``Conv2D``.
          filter_size: Kernel size given to ``Conv2D``.
          padding: Padding given to ``Conv2D``. Defaults to 0.
          name: Optional prefix for the parameter names. When provided, the
              weight is named ``<name>_weights`` and the bias
              ``<name>_offset``. When ``None``, no explicit parameter names
              are requested (``ParamAttr(name=None)``).
      """

      def __init__(self,
                   input_channels,
                   output_channels,
                   filter_size,
                   padding=0,
                   name=None):
          super(MakeFireConv, self).__init__()
          # ``name`` defaults to None and ``None + "_weights"`` raises
          # TypeError, so explicit names are only built when a prefix is given.
          weight_name = None if name is None else name + "_weights"
          bias_name = None if name is None else name + "_offset"
          self._conv = Conv2D(
              input_channels,
              output_channels,
              filter_size,
              padding=padding,
              weight_attr=ParamAttr(name=weight_name),
              bias_attr=ParamAttr(name=bias_name))

      def forward(self, x):
          """Apply the convolution to ``x`` and pass the result through ReLU."""
          return F.relu(self._conv(x))
  class MakeFire(nn.Layer):
      """Fire module: a 1x1 squeeze conv feeding parallel 1x1 and 3x3 convs.

      The outputs of the two expand branches are concatenated along axis 1.

      Args:
          input_channels: Number of channels entering the squeeze convolution.
          squeeze_channels: Output channels of the 1x1 squeeze convolution.
          expand1x1_channels: Output channels of the 1x1 expand branch.
          expand3x3_channels: Output channels of the 3x3 expand branch
              (built with ``padding=1``).
          name: Optional prefix for the sub-layer names. When provided, the
              sub-layers receive ``<name>_squeeze1x1``, ``<name>_expand1x1``
              and ``<name>_expand3x3``. When ``None``, ``None`` is forwarded
              to each ``MakeFireConv``, matching that class's declared
              default.
      """

      def __init__(self,
                   input_channels,
                   squeeze_channels,
                   expand1x1_channels,
                   expand3x3_channels,
                   name=None):
          super(MakeFire, self).__init__()

          def child_name(suffix):
              # ``name`` defaults to None and ``None + suffix`` raises
              # TypeError, so a missing prefix is forwarded unchanged.
              return None if name is None else name + suffix

          self._conv = MakeFireConv(
              input_channels,
              squeeze_channels,
              1,
              name=child_name("_squeeze1x1"))
          self._conv_path1 = MakeFireConv(
              squeeze_channels,
              expand1x1_channels,
              1,
              name=child_name("_expand1x1"))
          self._conv_path2 = MakeFireConv(
              squeeze_channels,
              expand3x3_channels,
              3,
              padding=1,
              name=child_name("_expand3x3"))

      def forward(self, inputs):
          """Squeeze ``inputs``, run both expand branches, and concatenate."""
          squeezed = self._conv(inputs)
          expand1x1 = self._conv_path1(squeezed)
          expand3x3 = self._conv_path2(squeezed)
          return paddle.concat([expand1x1, expand3x3], axis=1)
  class SqueezeNet(nn.Layer):
      """SqueezeNet classifier built from a stem conv and eight Fire modules.

      Args:
          version: ``"1.0"`` selects the 7x7 / 96-channel stem and the pooling
              schedule that goes with it; any other value selects the
              3x3 / 64-channel stem (this file uses ``"1.1"`` for that case).
          class_dim: Output channels of the final 1x1 convolution.
              Defaults to 1000.
      """

      def __init__(self, version, class_dim=1000):
          super(SqueezeNet, self).__init__()
          self.version = version

          # The two variants differ only in the stem convolution (and hence in
          # the input width of the first Fire module); everything after that
          # is constructed identically.
          if self.version == "1.0":
              stem_channels, stem_kernel, stem_padding = 96, 7, 0
          else:
              stem_channels, stem_kernel, stem_padding = 64, 3, 1

          self._conv = Conv2D(
              3,
              stem_channels,
              stem_kernel,
              stride=2,
              padding=stem_padding,
              weight_attr=ParamAttr(name="conv1_weights"),
              bias_attr=ParamAttr(name="conv1_offset"))
          self._pool = MaxPool2D(kernel_size=3, stride=2, padding=0)

          self._conv1 = MakeFire(stem_channels, 16, 64, 64, name="fire2")
          self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
          self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
          self._conv4 = MakeFire(256, 32, 128, 128, name="fire5")
          self._conv5 = MakeFire(256, 48, 192, 192, name="fire6")
          self._conv6 = MakeFire(384, 48, 192, 192, name="fire7")
          self._conv7 = MakeFire(384, 64, 256, 256, name="fire8")
          self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")

          self._drop = Dropout(p=0.5, mode="downscale_in_infer")
          self._conv9 = Conv2D(
              512,
              class_dim,
              1,
              weight_attr=ParamAttr(name="conv10_weights"),
              bias_attr=ParamAttr(name="conv10_offset"))
          self._avg_pool = AdaptiveAvgPool2D(1)

      def forward(self, inputs):
          """Run the network on ``inputs`` and return the pooled class scores.

          The result has axes 2 and 3 squeezed out after the adaptive average
          pool, so it is presumably of shape ``(batch, class_dim)`` for a 4-D
          input -- confirm against callers.
          """
          out = self._pool(F.relu(self._conv(inputs)))

          # Same Fire modules in both variants; only the positions of the two
          # intermediate max-pools differ.
          if self.version == "1.0":
              stages = (
                  self._conv1, self._conv2, self._conv3, self._pool,
                  self._conv4, self._conv5, self._conv6, self._conv7,
                  self._pool, self._conv8,
              )
          else:
              stages = (
                  self._conv1, self._conv2, self._pool,
                  self._conv3, self._conv4, self._pool,
                  self._conv5, self._conv6, self._conv7, self._conv8,
              )
          for stage in stages:
              out = stage(out)

          out = self._drop(out)
          out = F.relu(self._conv9(out))
          out = self._avg_pool(out)
          return paddle.squeeze(out, axis=[2, 3])
  def SqueezeNet1_0(**args):
      """Build a ``SqueezeNet`` with ``version="1.0"``.

      All keyword arguments are forwarded unchanged to ``SqueezeNet``.
      """
      return SqueezeNet(version="1.0", **args)
  def SqueezeNet1_1(**args):
      """Build a ``SqueezeNet`` with ``version="1.1"``.

      All keyword arguments are forwarded unchanged to ``SqueezeNet``.
      """
      return SqueezeNet(version="1.1", **args)