# vgg.py

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D

__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]

class ConvBlock(nn.Layer):
    """One VGG stage: `groups` stacked 3x3 conv + ReLU layers followed by a 2x2 max-pool."""

    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()

        self.groups = groups
        self._conv_1 = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(name=name + "1_weights"),
            bias_attr=False)
        if groups == 2 or groups == 3 or groups == 4:
            self._conv_2 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "2_weights"),
                bias_attr=False)
        if groups == 3 or groups == 4:
            self._conv_3 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "3_weights"),
                bias_attr=False)
        if groups == 4:
            self._conv_4 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "4_weights"),
                bias_attr=False)

        self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        if self.groups == 2 or self.groups == 3 or self.groups == 4:
            x = self._conv_2(x)
            x = F.relu(x)
        if self.groups == 3 or self.groups == 4:
            x = self._conv_3(x)
            x = F.relu(x)
        if self.groups == 4:
            x = self._conv_4(x)
            x = F.relu(x)
        x = self._pool(x)
        return x
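
# Example (illustrative sketch, not part of the original file): a ConvBlock
# with groups=2 applies two 3x3 conv + ReLU pairs and then a 2x2 max-pool,
# which halves the spatial resolution while keeping the channel count fixed:
#
#   block = ConvBlock(3, 64, groups=2, name="demo_")
#   y = block(paddle.rand([1, 3, 224, 224]))   # y.shape -> [1, 64, 112, 112]
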
class VGGNet(nn.Layer):
    def __init__(self, layers=11, stop_grad_layers=0, class_dim=1000):
        super(VGGNet, self).__init__()

        self.layers = layers
        self.stop_grad_layers = stop_grad_layers
        # Number of conv layers in each of the five stages, per VGG variant.
        self.vgg_configure = {
            11: [1, 1, 2, 2, 2],
            13: [2, 2, 2, 2, 2],
            16: [2, 2, 3, 3, 3],
            19: [2, 2, 4, 4, 4]
        }
        assert self.layers in self.vgg_configure.keys(), \
            "supported layers are {} but input layer is {}".format(
                self.vgg_configure.keys(), layers)
        self.groups = self.vgg_configure[self.layers]

        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

        # Optionally freeze the first `stop_grad_layers` conv blocks.
        for idx, block in enumerate([
                self._conv_block_1, self._conv_block_2, self._conv_block_3,
                self._conv_block_4, self._conv_block_5
        ]):
            if self.stop_grad_layers >= idx + 1:
                for param in block.parameters():
                    param.trainable = False

        self._drop = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc1 = Linear(
            7 * 7 * 512,
            4096,
            weight_attr=ParamAttr(name="fc6_weights"),
            bias_attr=ParamAttr(name="fc6_offset"))
        self._fc2 = Linear(
            4096,
            4096,
            weight_attr=ParamAttr(name="fc7_weights"),
            bias_attr=ParamAttr(name="fc7_offset"))
        self._out = Linear(
            4096,
            class_dim,
            weight_attr=ParamAttr(name="fc8_weights"),
            bias_attr=ParamAttr(name="fc8_offset"))

    def forward(self, inputs):
        x = self._conv_block_1(inputs)
        x = self._conv_block_2(x)
        x = self._conv_block_3(x)
        x = self._conv_block_4(x)
        x = self._conv_block_5(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self._fc1(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._fc2(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._out(x)
        return x
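
# The first FC layer (`self._fc1`) expects a flattened 7 x 7 x 512 feature map,
# so the network assumes 224 x 224 inputs (five stride-2 max-pools: 224 / 2**5 = 7).
# Minimal forward-pass sketch (illustrative, not part of the original file):
#
#   net = VGGNet(layers=16, class_dim=1000)
#   logits = net(paddle.rand([1, 3, 224, 224]))   # logits.shape -> [1, 1000]
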
def VGG11(**args):
    model = VGGNet(layers=11, **args)
    return model


def VGG13(**args):
    model = VGGNet(layers=13, **args)
    return model


def VGG16(**args):
    model = VGGNet(layers=16, **args)
    return model


def VGG19(**args):
    model = VGGNet(layers=19, **args)
    return model
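

if __name__ == "__main__":
    # Quick smoke test (illustrative sketch, not part of the original file).
    # Note: parameter names are hard-coded via ParamAttr above, so building
    # several variants in one process may clash; this builds a single VGG16.
    model = VGG16(class_dim=1000)
    model.eval()
    out = model(paddle.rand([1, 3, 224, 224]))
    print(out.shape)  # expected: [1, 1000]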