dnlnet.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.models import layers
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class DNLNet(nn.Layer):
  22. """Disentangled Non-Local Neural Networks.
  23. The original article refers to
  24. Minghao Yin, et al. "Disentangled Non-Local Neural Networks"
  25. (https://arxiv.org/abs/2006.06668)
  26. Args:
  27. num_classes (int): The unique number of target classes.
  28. backbone (Paddle.nn.Layer): A backbone network.
  29. backbone_indices (tuple): The values in the tuple indicate the indices of output of backbone.
  30. reduction (int): Reduction factor of projection transform. Default: 2.
  31. use_scale (bool): Whether to scale pairwise_weight by
  32. sqrt(1/inter_channels). Default: False.
  33. mode (str): The nonlocal mode. Options are 'embedded_gaussian',
  34. 'dot_product'. Default: 'embedded_gaussian'.
  35. temperature (float): Temperature to adjust attention. Default: 0.05.
  36. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  37. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  38. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  39. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  40. pretrained (str, optional): The path or url of pretrained model. Default: None.
  41. """
  42. def __init__(self,
  43. num_classes,
  44. backbone,
  45. backbone_indices=(2, 3),
  46. reduction=2,
  47. use_scale=True,
  48. mode='embedded_gaussian',
  49. temperature=0.05,
  50. concat_input=True,
  51. enable_auxiliary_loss=True,
  52. align_corners=False,
  53. pretrained=None):
  54. super().__init__()
  55. self.backbone = backbone
  56. self.backbone_indices = backbone_indices
  57. in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
  58. self.head = DNLHead(num_classes, in_channels, reduction, use_scale,
  59. mode, temperature, concat_input,
  60. enable_auxiliary_loss)
  61. self.align_corners = align_corners
  62. self.pretrained = pretrained
  63. self.init_weight()
  64. def forward(self, x):
  65. feats = self.backbone(x)
  66. feats = [feats[i] for i in self.backbone_indices]
  67. logit_list = self.head(feats)
  68. logit_list = [
  69. F.interpolate(
  70. logit,
  71. paddle.shape(x)[2:],
  72. mode='bilinear',
  73. align_corners=self.align_corners,
  74. align_mode=1) for logit in logit_list
  75. ]
  76. return logit_list
  77. def init_weight(self):
  78. if self.pretrained is not None:
  79. utils.load_entire_model(self, self.pretrained)
  80. class DNLHead(nn.Layer):
  81. """
  82. The DNLNet head.
  83. Args:
  84. num_classes (int): The unique number of target classes.
  85. in_channels (tuple): The number of input channels.
  86. reduction (int): Reduction factor of projection transform. Default: 2.
  87. use_scale (bool): Whether to scale pairwise_weight by
  88. sqrt(1/inter_channels). Default: False.
  89. mode (str): The nonlocal mode. Options are 'embedded_gaussian',
  90. 'dot_product'. Default: 'embedded_gaussian.'.
  91. temperature (float): Temperature to adjust attention. Default: 0.05
  92. concat_input (bool): Whether concat the input and output of convs before classification layer. Default: True
  93. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  94. """
  95. def __init__(self,
  96. num_classes,
  97. in_channels,
  98. reduction,
  99. use_scale,
  100. mode,
  101. temperature,
  102. concat_input=True,
  103. enable_auxiliary_loss=True,
  104. **kwargs):
  105. super(DNLHead, self).__init__()
  106. self.in_channels = in_channels[-1]
  107. self.concat_input = concat_input
  108. self.enable_auxiliary_loss = enable_auxiliary_loss
  109. inter_channels = self.in_channels // 4
  110. self.dnl_block = DisentangledNonLocal2D(
  111. in_channels=inter_channels,
  112. reduction=reduction,
  113. use_scale=use_scale,
  114. temperature=temperature,
  115. mode=mode)
  116. self.conv0 = layers.ConvBNReLU(
  117. in_channels=self.in_channels,
  118. out_channels=inter_channels,
  119. kernel_size=3,
  120. bias_attr=False)
  121. self.conv1 = layers.ConvBNReLU(
  122. in_channels=inter_channels,
  123. out_channels=inter_channels,
  124. kernel_size=3,
  125. bias_attr=False)
  126. self.cls = nn.Sequential(
  127. nn.Dropout2D(p=0.1), nn.Conv2D(inter_channels, num_classes, 1))
  128. self.aux = nn.Sequential(
  129. layers.ConvBNReLU(
  130. in_channels=1024,
  131. out_channels=256,
  132. kernel_size=3,
  133. bias_attr=False), nn.Dropout2D(p=0.1),
  134. nn.Conv2D(256, num_classes, 1))
  135. if self.concat_input:
  136. self.conv_cat = layers.ConvBNReLU(
  137. self.in_channels + inter_channels,
  138. inter_channels,
  139. kernel_size=3,
  140. bias_attr=False)
  141. def forward(self, feat_list):
  142. C3, C4 = feat_list
  143. output = self.conv0(C4)
  144. output = self.dnl_block(output)
  145. output = self.conv1(output)
  146. if self.concat_input:
  147. output = self.conv_cat(paddle.concat([C4, output], axis=1))
  148. output = self.cls(output)
  149. if self.enable_auxiliary_loss:
  150. auxout = self.aux(C3)
  151. return [output, auxout]
  152. else:
  153. return [output]
  154. class DisentangledNonLocal2D(layers.NonLocal2D):
  155. """Disentangled Non-Local Blocks.
  156. Args:
  157. temperature (float): Temperature to adjust attention.
  158. """
  159. def __init__(self, temperature, *arg, **kwargs):
  160. super().__init__(*arg, **kwargs)
  161. self.temperature = temperature
  162. self.conv_mask = nn.Conv2D(self.in_channels, 1, kernel_size=1)
  163. def embedded_gaussian(self, theta_x, phi_x):
  164. pairwise_weight = paddle.matmul(theta_x, phi_x)
  165. if self.use_scale:
  166. pairwise_weight /= theta_x.shape[-1]**0.5
  167. pairwise_weight /= self.temperature
  168. pairwise_weight = F.softmax(pairwise_weight, -1)
  169. return pairwise_weight
  170. def forward(self, x):
  171. x_shape = paddle.shape(x)
  172. g_x = self.g(x).reshape([0, self.inter_channels,
  173. -1]).transpose([0, 2, 1])
  174. if self.mode == "gaussian":
  175. theta_x = paddle.transpose(
  176. x.reshape([0, self.in_channels, -1]), [0, 2, 1])
  177. if self.sub_sample:
  178. phi_x = paddle.transpose(self.phi(x), [0, self.in_channels, -1])
  179. else:
  180. phi_x = paddle.transpose(x, [0, self.in_channels, -1])
  181. elif self.mode == "concatenation":
  182. theta_x = paddle.reshape(
  183. self.theta(x), [0, self.inter_channels, -1, 1])
  184. phi_x = paddle.reshape(self.phi(x), [0, self.inter_channels, 1, -1])
  185. else:
  186. theta_x = self.theta(x).reshape([0, self.inter_channels,
  187. -1]).transpose([0, 2, 1])
  188. phi_x = paddle.reshape(self.phi(x), [0, self.inter_channels, -1])
  189. theta_x -= paddle.mean(theta_x, axis=-2, keepdim=True)
  190. phi_x -= paddle.mean(phi_x, axis=-1, keepdim=True)
  191. pairwise_func = getattr(self, self.mode)
  192. pairwise_weight = pairwise_func(theta_x, phi_x)
  193. y = paddle.matmul(pairwise_weight, g_x).transpose([0, 2, 1]).reshape(
  194. [0, self.inter_channels, x_shape[2], x_shape[3]])
  195. unary_mask = F.softmax(
  196. paddle.reshape(self.conv_mask(x), [0, 1, -1]), -1)
  197. unary_x = paddle.matmul(unary_mask, g_x).transpose([0, 2, 1]).reshape(
  198. [0, self.inter_channels, 1, 1])
  199. output = x + self.conv_out(y + unary_x)
  200. return output