emanet.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class EMANet(nn.Layer):
  22. """
  23. Expectation Maximization Attention Networks for Semantic Segmentation based on PaddlePaddle.
  24. The original article refers to
  25. Xia Li, et al. "Expectation-Maximization Attention Networks for Semantic Segmentation"
  26. (https://arxiv.org/abs/1907.13426)
  27. Args:
  28. num_classes (int): The unique number of target classes.
  29. backbone (Paddle.nn.Layer): A backbone network.
  30. backbone_indices (tuple): The values in the tuple indicate the indices of output of backbone.
  31. ema_channels (int): EMA module channels.
  32. gc_channels (int): The input channels to Global Context Block.
  33. num_bases (int): Number of bases.
  34. stage_num (int): The iteration number for EM.
  35. momentum (float): The parameter for updating bases.
  36. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  37. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  38. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  39. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  40. pretrained (str, optional): The path or url of pretrained model. Default: None.
  41. """
  42. def __init__(self,
  43. num_classes,
  44. backbone,
  45. backbone_indices=(2, 3),
  46. ema_channels=512,
  47. gc_channels=256,
  48. num_bases=64,
  49. stage_num=3,
  50. momentum=0.1,
  51. concat_input=True,
  52. enable_auxiliary_loss=True,
  53. align_corners=False,
  54. pretrained=None):
  55. super().__init__()
  56. self.backbone = backbone
  57. self.backbone_indices = backbone_indices
  58. in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
  59. self.head = EMAHead(num_classes, in_channels, ema_channels, gc_channels,
  60. num_bases, stage_num, momentum, concat_input,
  61. enable_auxiliary_loss)
  62. self.align_corners = align_corners
  63. self.pretrained = pretrained
  64. self.init_weight()
  65. def forward(self, x):
  66. feats = self.backbone(x)
  67. feats = [feats[i] for i in self.backbone_indices]
  68. logit_list = self.head(feats)
  69. logit_list = [
  70. F.interpolate(
  71. logit,
  72. paddle.shape(x)[2:],
  73. mode='bilinear',
  74. align_corners=self.align_corners) for logit in logit_list
  75. ]
  76. return logit_list
  77. def init_weight(self):
  78. if self.pretrained is not None:
  79. utils.load_entire_model(self, self.pretrained)
  80. class EMAHead(nn.Layer):
  81. """
  82. The EMANet head.
  83. Args:
  84. num_classes (int): The unique number of target classes.
  85. in_channels (tuple): The number of input channels.
  86. ema_channels (int): EMA module channels.
  87. gc_channels (int): The input channels to Global Context Block.
  88. num_bases (int): Number of bases.
  89. stage_num (int): The iteration number for EM.
  90. momentum (float): The parameter for updating bases.
  91. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  92. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  93. """
  94. def __init__(self,
  95. num_classes,
  96. in_channels,
  97. ema_channels,
  98. gc_channels,
  99. num_bases,
  100. stage_num,
  101. momentum,
  102. concat_input=True,
  103. enable_auxiliary_loss=True):
  104. super(EMAHead, self).__init__()
  105. self.in_channels = in_channels[-1]
  106. self.concat_input = concat_input
  107. self.enable_auxiliary_loss = enable_auxiliary_loss
  108. self.emau = EMAU(ema_channels, num_bases, stage_num, momentum=momentum)
  109. self.ema_in_conv = layers.ConvBNReLU(
  110. in_channels=self.in_channels,
  111. out_channels=ema_channels,
  112. kernel_size=3)
  113. self.ema_mid_conv = nn.Conv2D(ema_channels, ema_channels, kernel_size=1)
  114. self.ema_out_conv = layers.ConvBNReLU(
  115. in_channels=ema_channels, out_channels=ema_channels, kernel_size=1)
  116. self.bottleneck = layers.ConvBNReLU(
  117. in_channels=ema_channels, out_channels=gc_channels, kernel_size=3)
  118. self.cls = nn.Sequential(
  119. nn.Dropout2D(p=0.1), nn.Conv2D(gc_channels, num_classes, 1))
  120. self.aux = nn.Sequential(
  121. layers.ConvBNReLU(
  122. in_channels=1024, out_channels=256, kernel_size=3),
  123. nn.Dropout2D(p=0.1), nn.Conv2D(256, num_classes, 1))
  124. if self.concat_input:
  125. self.conv_cat = layers.ConvBNReLU(
  126. self.in_channels + gc_channels, gc_channels, kernel_size=3)
  127. def forward(self, feat_list):
  128. C3, C4 = feat_list
  129. feats = self.ema_in_conv(C4)
  130. identity = feats
  131. feats = self.ema_mid_conv(feats)
  132. recon = self.emau(feats)
  133. recon = F.relu(recon)
  134. recon = self.ema_out_conv(recon)
  135. output = F.relu(identity + recon)
  136. output = self.bottleneck(output)
  137. if self.concat_input:
  138. output = self.conv_cat(paddle.concat([C4, output], axis=1))
  139. output = self.cls(output)
  140. if self.enable_auxiliary_loss:
  141. auxout = self.aux(C3)
  142. return [output, auxout]
  143. else:
  144. return [output]
  145. class EMAU(nn.Layer):
  146. '''The Expectation-Maximization Attention Unit (EMAU).
  147. Arguments:
  148. c (int): The input and output channel number.
  149. k (int): The number of the bases.
  150. stage_num (int): The iteration number for EM.
  151. momentum (float): The parameter for updating bases.
  152. '''
  153. def __init__(self, c, k, stage_num=3, momentum=0.1):
  154. super(EMAU, self).__init__()
  155. assert stage_num >= 1
  156. self.stage_num = stage_num
  157. self.momentum = momentum
  158. self.c = c
  159. tmp_mu = self.create_parameter(
  160. shape=[1, c, k],
  161. default_initializer=paddle.nn.initializer.KaimingNormal(k))
  162. mu = F.normalize(paddle.to_tensor(tmp_mu), axis=1, p=2)
  163. self.register_buffer('mu', mu)
  164. def forward(self, x):
  165. x_shape = paddle.shape(x)
  166. x = x.flatten(2)
  167. mu = paddle.tile(self.mu, [x_shape[0], 1, 1])
  168. with paddle.no_grad():
  169. for i in range(self.stage_num):
  170. x_t = paddle.transpose(x, [0, 2, 1])
  171. z = paddle.bmm(x_t, mu)
  172. z = F.softmax(z, axis=2)
  173. z_ = F.normalize(z, axis=1, p=1)
  174. mu = paddle.bmm(x, z_)
  175. mu = F.normalize(mu, axis=1, p=2)
  176. z_t = paddle.transpose(z, [0, 2, 1])
  177. x = paddle.matmul(mu, z_t)
  178. x = paddle.reshape(x, [0, self.c, x_shape[2], x_shape[3]])
  179. if self.training:
  180. mu = paddle.mean(mu, 0, keepdim=True)
  181. mu = F.normalize(mu, axis=1, p=2)
  182. mu = self.mu * (1 - self.momentum) + mu * self.momentum
  183. if paddle.distributed.get_world_size() > 1:
  184. mu = paddle.distributed.all_reduce(mu)
  185. mu /= paddle.distributed.get_world_size()
  186. self.mu = mu
  187. return x