gcnet.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlex.paddleseg.cvlibs import manager
from paddlex.paddleseg.models import layers
from paddlex.paddleseg.utils import utils


@manager.MODELS.add_component
class GCNet(nn.Layer):
    """
    The GCNet implementation based on PaddlePaddle.

    The original article refers to
    Cao, Yue, et al. "GCNet: Non-local networks meet squeeze-excitation networks and beyond"
    (https://arxiv.org/pdf/1904.11492.pdf).

    Args:
        num_classes (int): The number of target classes.
        backbone (paddle.nn.Layer): Backbone network; currently supports ResNet50/101.
        backbone_indices (tuple, optional): Two values indicating the indices of backbone output.
            Default: (2, 3).
        gc_channels (int, optional): The input channels of the Global Context Block. Default: 512.
        ratio (float, optional): The ratio of attention channels to gc_channels. Default: 0.25.
        enable_auxiliary_loss (bool, optional): Whether to add an auxiliary loss. Default: True.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False
            when the feature size is even, e.g. 1024x512; otherwise it should be True,
            e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of the pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(2, 3),
                 gc_channels=512,
                 ratio=0.25,
                 enable_auxiliary_loss=True,
                 align_corners=False,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        backbone_channels = [
            backbone.feat_channels[i] for i in backbone_indices
        ]
        self.head = GCNetHead(num_classes, backbone_indices, backbone_channels,
                              gc_channels, ratio, enable_auxiliary_loss)
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        feat_list = self.backbone(x)
        logit_list = self.head(feat_list)
        # Upsample every logit map back to the input resolution.
        return [
            F.interpolate(
                logit,
                paddle.shape(x)[2:],
                mode='bilinear',
                align_corners=self.align_corners) for logit in logit_list
        ]

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
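

# Usage sketch (illustrative addition, not part of the upstream file): wiring
# GCNet to a backbone that exposes `feat_channels`, e.g. a ResNet50_vd from
# paddlex.paddleseg.models.backbones. The backbone name and import path here
# are assumptions for illustration.
#
#     from paddlex.paddleseg.models.backbones import ResNet50_vd
#
#     model = GCNet(num_classes=19, backbone=ResNet50_vd())
#     logits = model(paddle.rand([1, 3, 512, 1024]))
#     # With enable_auxiliary_loss=True, `logits` holds the main and the
#     # auxiliary prediction, both upsampled to [1, 19, 512, 1024].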


class GCNetHead(nn.Layer):
    """
    The GCNetHead implementation.

    Args:
        num_classes (int): The number of target classes.
        backbone_indices (tuple): Two values indicating the indices of backbone output.
            The first index is taken as the deep-supervision feature for the auxiliary layer;
            the second one is taken as the input of the Global Context Block.
        backbone_channels (tuple): The same length as "backbone_indices". It indicates
            the channels of the corresponding indices.
        gc_channels (int): The input channels of the Global Context Block.
        ratio (float): The ratio of attention channels to gc_channels.
        enable_auxiliary_loss (bool, optional): Whether to add an auxiliary loss. Default: True.
    """

    def __init__(self,
                 num_classes,
                 backbone_indices,
                 backbone_channels,
                 gc_channels,
                 ratio,
                 enable_auxiliary_loss=True):
        super().__init__()
        in_channels = backbone_channels[1]
        self.conv_bn_relu1 = layers.ConvBNReLU(
            in_channels=in_channels,
            out_channels=gc_channels,
            kernel_size=3,
            padding=1)
        self.gc_block = GlobalContextBlock(
            gc_channels=gc_channels, in_channels=gc_channels, ratio=ratio)
        self.conv_bn_relu2 = layers.ConvBNReLU(
            in_channels=gc_channels,
            out_channels=gc_channels,
            kernel_size=3,
            padding=1)
        self.conv_bn_relu3 = layers.ConvBNReLU(
            in_channels=in_channels + gc_channels,
            out_channels=gc_channels,
            kernel_size=3,
            padding=1)
        self.dropout = nn.Dropout(p=0.1)
        self.conv = nn.Conv2D(
            in_channels=gc_channels, out_channels=num_classes, kernel_size=1)
        if enable_auxiliary_loss:
            self.auxlayer = layers.AuxLayer(
                in_channels=backbone_channels[0],
                inter_channels=backbone_channels[0] // 4,
                out_channels=num_classes)
        self.backbone_indices = backbone_indices
        self.enable_auxiliary_loss = enable_auxiliary_loss

    def forward(self, feat_list):
        logit_list = []
        x = feat_list[self.backbone_indices[1]]
        # Reduce channels, apply the global context block, then refine.
        output = self.conv_bn_relu1(x)
        output = self.gc_block(output)
        output = self.conv_bn_relu2(output)
        # Fuse the context-enhanced feature with the raw backbone feature.
        output = paddle.concat([x, output], axis=1)
        output = self.conv_bn_relu3(output)
        output = self.dropout(output)
        logit = self.conv(output)
        logit_list.append(logit)
        if self.enable_auxiliary_loss:
            low_level_feat = feat_list[self.backbone_indices[0]]
            auxiliary_logit = self.auxlayer(low_level_feat)
            logit_list.append(auxiliary_logit)
        return logit_list
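

# Channel bookkeeping (illustrative note, assuming a ResNet50 backbone whose
# feat_channels are (256, 512, 1024, 2048) and the defaults above): with
# backbone_indices=(2, 3) and gc_channels=512, conv_bn_relu1 maps 2048 -> 512,
# the GC block keeps 512, the concat yields 2048 + 512 = 2560 channels, and
# conv_bn_relu3 maps 2560 -> 512 before the final 1x1 classifier.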


class GlobalContextBlock(nn.Layer):
    """
    Global Context Block implementation.

    It pools the input feature into a [N, C, 1, 1] global context vector with a
    softmax attention map, transforms it with a channel-add bottleneck, and adds
    the result back to the input.

    Args:
        gc_channels (int): The input channels of the Global Context Block.
        in_channels (int): The input channels of the attention-mask convolution.
        ratio (float): The ratio of attention channels to in_channels.
    """

    def __init__(self, gc_channels, in_channels, ratio):
        super().__init__()
        self.gc_channels = gc_channels
        self.conv_mask = nn.Conv2D(
            in_channels=in_channels, out_channels=1, kernel_size=1)
        self.softmax = nn.Softmax(axis=2)

        inter_channels = int(in_channels * ratio)
        self.channel_add_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=in_channels,
                out_channels=inter_channels,
                kernel_size=1),
            nn.LayerNorm(normalized_shape=[inter_channels, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                in_channels=inter_channels,
                out_channels=in_channels,
                kernel_size=1))

    def global_context_block(self, x):
        # [N, C, H * W]
        input_x = paddle.reshape(x, shape=[0, self.gc_channels, -1])
        # [N, 1, C, H * W]
        input_x = paddle.unsqueeze(input_x, axis=1)
        # [N, 1, H, W]
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = paddle.reshape(context_mask, shape=[0, 1, -1])
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = paddle.unsqueeze(context_mask, axis=-1)
        # [N, 1, C, 1]
        context = paddle.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = paddle.reshape(context, shape=[0, self.gc_channels, 1, 1])
        return context

    def forward(self, x):
        context = self.global_context_block(x)
        channel_add_term = self.channel_add_conv(context)
        out = x + channel_add_term
        return out
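

if __name__ == "__main__":
    # Minimal shape sanity check (illustrative addition, not part of the
    # upstream file). The block ends with a residual add, so the output
    # shape must match the input shape.
    block = GlobalContextBlock(gc_channels=512, in_channels=512, ratio=0.25)
    feats = paddle.rand([2, 512, 33, 33])
    out = block(feats)
    assert out.shape == feats.shape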