attention.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlex.paddleseg.models import layers


class AttentionBlock(nn.Layer):
    """General self-attention block/non-local block.

    The original article refers to https://arxiv.org/abs/1706.03762.

    Args:
        key_in_channels (int): Input channels of the key feature.
        query_in_channels (int): Input channels of the query feature.
        channels (int): Output channels of the key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
        query_downsample (nn.Layer): Query downsample module.
        key_downsample (nn.Layer): Key downsample module.
        key_query_num_convs (int): Number of convs for the key/query
            projection.
        value_out_num_convs (int): Number of convs for the value projection.
        key_query_norm (bool): Whether to use BN for the key/query projection.
        value_out_norm (bool): Whether to use BN for the value projection.
        matmul_norm (bool): Whether to normalize the attention map by
            sqrt(channels).
        with_out (bool): Whether to use an output projection.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out):
        super(AttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.with_out = with_out
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm)
        self.value_project = self.build_project(
            key_in_channels,
            channels if self.with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm)
        if self.with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm)
        else:
            self.out_project = None
        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

    def build_project(self, in_channels, channels, num_convs, use_conv_module):
        if use_conv_module:
            convs = [
                layers.ConvBNReLU(
                    in_channels=in_channels,
                    out_channels=channels,
                    kernel_size=1,
                    bias_attr=False)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    layers.ConvBNReLU(
                        in_channels=channels,
                        out_channels=channels,
                        kernel_size=1,
                        bias_attr=False))
        else:
            convs = [nn.Conv2D(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2D(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        query_shape = paddle.shape(query_feats)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        # Flatten spatial dims: [N, C, H, W] -> [N, H*W, C].
        query = query.flatten(2).transpose([0, 2, 1])

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.flatten(2)
        value = value.flatten(2).transpose([0, 2, 1])

        # Similarity between every query position and every key position,
        # optionally scaled by 1/sqrt(channels), then normalized row-wise.
        sim_map = paddle.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-0.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        # Aggregate values with the attention weights and restore the
        # [N, C, H, W] layout of the query features.
        context = paddle.matmul(sim_map, value)
        context = paddle.transpose(context, [0, 2, 1])
        context = paddle.reshape(
            context, [0, self.out_channels, query_shape[2], query_shape[3]])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
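

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, assumed example of running AttentionBlock as plain self-attention
# over a single feature map. All channel sizes and the 32x32 spatial size are
# illustrative only. Note that with `with_out=True` this implementation
# reshapes the context with `out_channels`, so `channels` and `out_channels`
# are kept equal here.
if __name__ == '__main__':
    block = AttentionBlock(
        key_in_channels=256,
        query_in_channels=256,
        channels=64,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True)
    feats = paddle.rand([2, 256, 32, 32])
    # Self-attention: the same tensor provides both queries and keys/values.
    out = block(feats, feats)
    print(out.shape)  # expected: [2, 64, 32, 32]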