# isanet.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class ISANet(nn.Layer):
  22. """Interlaced Sparse Self-Attention for Semantic Segmentation.
  23. The original article refers to Lang Huang, et al. "Interlaced Sparse Self-Attention for Semantic Segmentation"
  24. (https://arxiv.org/abs/1907.12273).
  25. Args:
  26. num_classes (int): The unique number of target classes.
  27. backbone (Paddle.nn.Layer): A backbone network.
  28. backbone_indices (tuple): The values in the tuple indicate the indices of output of backbone.
  29. isa_channels (int): The channels of ISA Module.
  30. down_factor (tuple): Divide the height and width dimension to (Ph, PW) groups.
  31. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  32. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  33. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  34. pretrained (str, optional): The path or url of pretrained model. Default: None.
  35. """
  36. def __init__(self,
  37. num_classes,
  38. backbone,
  39. backbone_indices=(2, 3),
  40. isa_channels=256,
  41. down_factor=(8, 8),
  42. enable_auxiliary_loss=True,
  43. align_corners=False,
  44. pretrained=None):
  45. super().__init__()
  46. self.backbone = backbone
  47. self.backbone_indices = backbone_indices
  48. in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
  49. self.head = ISAHead(num_classes, in_channels, isa_channels, down_factor,
  50. enable_auxiliary_loss)
  51. self.align_corners = align_corners
  52. self.pretrained = pretrained
  53. self.init_weight()
  54. def forward(self, x):
  55. feats = self.backbone(x)
  56. feats = [feats[i] for i in self.backbone_indices]
  57. logit_list = self.head(feats)
  58. logit_list = [
  59. F.interpolate(
  60. logit,
  61. paddle.shape(x)[2:],
  62. mode='bilinear',
  63. align_corners=self.align_corners,
  64. align_mode=1) for logit in logit_list
  65. ]
  66. return logit_list
  67. def init_weight(self):
  68. if self.pretrained is not None:
  69. utils.load_entire_model(self, self.pretrained)
class ISAHead(nn.Layer):
    """
    The ISAHead: interlaced sparse self-attention over the deepest feature.

    Args:
        num_classes (int): The unique number of target classes.
        in_channels (tuple): The number of input channels.
        isa_channels (int): The channels of ISA Module.
        down_factor (tuple): Divide the height and width dimension to (Ph, PW) groups.
        enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
    """

    def __init__(self, num_classes, in_channels, isa_channels, down_factor,
                 enable_auxiliary_loss):
        super(ISAHead, self).__init__()
        # Only the last (deepest) feature map is used by the ISA branch.
        self.in_channels = in_channels[-1]
        inter_channels = self.in_channels // 4
        self.inter_channels = inter_channels
        self.down_factor = down_factor
        self.enable_auxiliary_loss = enable_auxiliary_loss
        # 3x3 conv reduces the input to inter_channels before attention.
        self.in_conv = layers.ConvBNReLU(
            self.in_channels, inter_channels, 3, bias_attr=False)
        # Two attention stages: long-range (across groups) then
        # short-range (within each group).
        self.global_relation = SelfAttentionBlock(inter_channels, isa_channels)
        self.local_relation = SelfAttentionBlock(inter_channels, isa_channels)
        # Fuses attended features with the pre-attention features (concat).
        self.out_conv = layers.ConvBNReLU(
            inter_channels * 2, inter_channels, 1, bias_attr=False)
        self.cls = nn.Sequential(
            nn.Dropout2D(p=0.1), nn.Conv2D(inter_channels, num_classes, 1))
        # NOTE(review): aux head hard-codes 1024 input channels — assumes the
        # C3 feature has 1024 channels (typical ResNet stage-3); confirm
        # against the configured backbone.
        self.aux = nn.Sequential(
            layers.ConvBNReLU(
                in_channels=1024,
                out_channels=256,
                kernel_size=3,
                bias_attr=False), nn.Dropout2D(p=0.1),
            nn.Conv2D(256, num_classes, 1))

    def forward(self, feat_list):
        # C3 drives the auxiliary head; C4 drives the main ISA branch.
        C3, C4 = feat_list
        x = self.in_conv(C4)
        x_shape = paddle.shape(x)
        # (P_h, P_w): number of groups per spatial axis; (Q_h, Q_w): the
        # (ceil-rounded) size of each group.
        P_h, P_w = self.down_factor
        Q_h, Q_w = paddle.ceil(x_shape[2] / P_h).astype('int32'), paddle.ceil(
            x_shape[3] / P_w).astype('int32')
        # Pad symmetrically so H and W divide evenly into P_h*Q_h, P_w*Q_w.
        pad_h, pad_w = (Q_h * P_h - x_shape[2]).astype('int32'), (
            Q_w * P_w - x_shape[3]).astype('int32')
        if pad_h > 0 or pad_w > 0:
            padding = paddle.concat([
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ],
                                    axis=0)
            feat = F.pad(x, padding)
        else:
            feat = x
        # Split H into (Q_h, P_h) and W into (Q_w, P_w); the leading 0 in
        # reshape keeps the batch dimension unchanged (Paddle semantics).
        feat = feat.reshape([0, x_shape[1], Q_h, P_h, Q_w, P_w])
        # Long-range attention: fold the P_h*P_w groups into the batch and
        # attend over the Q_h x Q_w "global" grid.
        feat = feat.transpose([0, 3, 5, 1, 2,
                               4]).reshape([-1, self.inter_channels, Q_h, Q_w])
        feat = self.global_relation(feat)
        # Short-range attention: regroup so each Q-cell's P_h x P_w local
        # window becomes the spatial extent.
        feat = feat.reshape([x_shape[0], P_h, P_w, x_shape[1], Q_h, Q_w])
        feat = feat.transpose([0, 4, 5, 3, 1,
                               2]).reshape([-1, self.inter_channels, P_h, P_w])
        feat = self.local_relation(feat)
        # Restore the original (N, C, H_pad, W_pad) layout.
        feat = feat.reshape([x_shape[0], Q_h, Q_w, x_shape[1], P_h, P_w])
        feat = feat.transpose([0, 3, 1, 4, 2, 5]).reshape(
            [0, self.inter_channels, P_h * Q_h, P_w * Q_w])
        # Crop the symmetric padding back off.
        if pad_h > 0 or pad_w > 0:
            feat = paddle.slice(
                feat,
                axes=[2, 3],
                starts=[pad_h // 2, pad_w // 2],
                ends=[pad_h // 2 + x_shape[2], pad_w // 2 + x_shape[3]])
        # Residual-style fusion with the pre-attention features.
        feat = self.out_conv(paddle.concat([feat, x], axis=1))
        output = self.cls(feat)
        if self.enable_auxiliary_loss:
            auxout = self.aux(C3)
            return [output, auxout]
        else:
            return [output]
  144. class SelfAttentionBlock(layers.AttentionBlock):
  145. """General self-attention block/non-local block.
  146. Args:
  147. in_channels (int): Input channels of key/query feature.
  148. channels (int): Output channels of key/query transform.
  149. """
  150. def __init__(self, in_channels, channels):
  151. super(SelfAttentionBlock, self).__init__(
  152. key_in_channels=in_channels,
  153. query_in_channels=in_channels,
  154. channels=channels,
  155. out_channels=in_channels,
  156. share_key_query=False,
  157. query_downsample=None,
  158. key_downsample=None,
  159. key_query_num_convs=2,
  160. key_query_norm=True,
  161. value_out_num_convs=1,
  162. value_out_norm=False,
  163. matmul_norm=True,
  164. with_out=False)
  165. self.output_project = self.build_project(
  166. in_channels, in_channels, num_convs=1, use_conv_module=True)
  167. def forward(self, x):
  168. context = super(SelfAttentionBlock, self).forward(x, x)
  169. return self.output_project(context)