ann.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.cvlibs import manager
  18. from paddlex.paddleseg.models import layers
  19. from paddlex.paddleseg.utils import utils
  20. @manager.MODELS.add_component
  21. class ANN(nn.Layer):
  22. """
  23. The ANN implementation based on PaddlePaddle.
  24. The original article refers to
  25. Zhen, Zhu, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation"
  26. (https://arxiv.org/pdf/1908.07678.pdf).
  27. Args:
  28. num_classes (int): The unique number of target classes.
  29. backbone (Paddle.nn.Layer): Backbone network, currently support Resnet50/101.
  30. backbone_indices (tuple, optional): Two values in the tuple indicate the indices of output of backbone.
  31. key_value_channels (int, optional): The key and value channels of self-attention map in both AFNB and APNB modules.
  32. Default: 256.
  33. inter_channels (int, optional): Both input and output channels of APNB modules. Default: 512.
  34. psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
  35. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  36. align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
  37. e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  38. pretrained (str, optional): The path or url of pretrained model. Default: None.
  39. """
  40. def __init__(self,
  41. num_classes,
  42. backbone,
  43. backbone_indices=(2, 3),
  44. key_value_channels=256,
  45. inter_channels=512,
  46. psp_size=(1, 3, 6, 8),
  47. enable_auxiliary_loss=True,
  48. align_corners=False,
  49. pretrained=None):
  50. super().__init__()
  51. self.backbone = backbone
  52. backbone_channels = [
  53. backbone.feat_channels[i] for i in backbone_indices
  54. ]
  55. self.head = ANNHead(num_classes, backbone_indices, backbone_channels,
  56. key_value_channels, inter_channels, psp_size,
  57. enable_auxiliary_loss)
  58. self.align_corners = align_corners
  59. self.pretrained = pretrained
  60. self.init_weight()
  61. def forward(self, x):
  62. feat_list = self.backbone(x)
  63. logit_list = self.head(feat_list)
  64. return [
  65. F.interpolate(
  66. logit,
  67. paddle.shape(x)[2:],
  68. mode='bilinear',
  69. align_corners=self.align_corners) for logit in logit_list
  70. ]
  71. def init_weight(self):
  72. if self.pretrained is not None:
  73. utils.load_entire_model(self, self.pretrained)
  74. class ANNHead(nn.Layer):
  75. """
  76. The ANNHead implementation.
  77. It mainly consists of AFNB and APNB modules.
  78. Args:
  79. num_classes (int): The unique number of target classes.
  80. backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
  81. The first index will be taken as low-level features; the second one will be
  82. taken as high-level features in AFNB module. Usually backbone consists of four
  83. downsampling stage, such as ResNet, and return an output of each stage. If it is (2, 3),
  84. it means taking feature map of the third stage and the fourth stage in backbone.
  85. backbone_channels (tuple): The same length with "backbone_indices". It indicates the channels of corresponding index.
  86. key_value_channels (int): The key and value channels of self-attention map in both AFNB and APNB modules.
  87. inter_channels (int): Both input and output channels of APNB modules.
  88. psp_size (tuple): The out size of pooled feature maps.
  89. enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
  90. """
  91. def __init__(self,
  92. num_classes,
  93. backbone_indices,
  94. backbone_channels,
  95. key_value_channels,
  96. inter_channels,
  97. psp_size,
  98. enable_auxiliary_loss=True):
  99. super().__init__()
  100. low_in_channels = backbone_channels[0]
  101. high_in_channels = backbone_channels[1]
  102. self.fusion = AFNB(
  103. low_in_channels=low_in_channels,
  104. high_in_channels=high_in_channels,
  105. out_channels=high_in_channels,
  106. key_channels=key_value_channels,
  107. value_channels=key_value_channels,
  108. dropout_prob=0.05,
  109. repeat_sizes=([1]),
  110. psp_size=psp_size)
  111. self.context = nn.Sequential(
  112. layers.ConvBNReLU(
  113. in_channels=high_in_channels,
  114. out_channels=inter_channels,
  115. kernel_size=3,
  116. padding=1),
  117. APNB(
  118. in_channels=inter_channels,
  119. out_channels=inter_channels,
  120. key_channels=key_value_channels,
  121. value_channels=key_value_channels,
  122. dropout_prob=0.05,
  123. repeat_sizes=([1]),
  124. psp_size=psp_size))
  125. self.cls = nn.Conv2D(
  126. in_channels=inter_channels, out_channels=num_classes, kernel_size=1)
  127. self.auxlayer = layers.AuxLayer(
  128. in_channels=low_in_channels,
  129. inter_channels=low_in_channels // 2,
  130. out_channels=num_classes,
  131. dropout_prob=0.05)
  132. self.backbone_indices = backbone_indices
  133. self.enable_auxiliary_loss = enable_auxiliary_loss
  134. def forward(self, feat_list):
  135. logit_list = []
  136. low_level_x = feat_list[self.backbone_indices[0]]
  137. high_level_x = feat_list[self.backbone_indices[1]]
  138. x = self.fusion(low_level_x, high_level_x)
  139. x = self.context(x)
  140. logit = self.cls(x)
  141. logit_list.append(logit)
  142. if self.enable_auxiliary_loss:
  143. auxiliary_logit = self.auxlayer(low_level_x)
  144. logit_list.append(auxiliary_logit)
  145. return logit_list
  146. class AFNB(nn.Layer):
  147. """
  148. Asymmetric Fusion Non-local Block.
  149. Args:
  150. low_in_channels (int): Low-level-feature channels.
  151. high_in_channels (int): High-level-feature channels.
  152. out_channels (int): Out channels of AFNB module.
  153. key_channels (int): The key channels in self-attention block.
  154. value_channels (int): The value channels in self-attention block.
  155. dropout_prob (float): The dropout rate of output.
  156. repeat_sizes (tuple, optional): The number of AFNB modules. Default: ([1]).
  157. psp_size (tuple. optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
  158. """
  159. def __init__(self,
  160. low_in_channels,
  161. high_in_channels,
  162. out_channels,
  163. key_channels,
  164. value_channels,
  165. dropout_prob,
  166. repeat_sizes=([1]),
  167. psp_size=(1, 3, 6, 8)):
  168. super().__init__()
  169. self.psp_size = psp_size
  170. self.stages = nn.LayerList([
  171. SelfAttentionBlock_AFNB(low_in_channels, high_in_channels,
  172. key_channels, value_channels, out_channels,
  173. size) for size in repeat_sizes
  174. ])
  175. self.conv_bn = layers.ConvBN(
  176. in_channels=out_channels + high_in_channels,
  177. out_channels=out_channels,
  178. kernel_size=1)
  179. self.dropout = nn.Dropout(p=dropout_prob)
  180. def forward(self, low_feats, high_feats):
  181. priors = [stage(low_feats, high_feats) for stage in self.stages]
  182. context = priors[0]
  183. for i in range(1, len(priors)):
  184. context += priors[i]
  185. output = self.conv_bn(paddle.concat([context, high_feats], axis=1))
  186. output = self.dropout(output)
  187. return output
  188. class APNB(nn.Layer):
  189. """
  190. Asymmetric Pyramid Non-local Block.
  191. Args:
  192. in_channels (int): The input channels of APNB module.
  193. out_channels (int): Out channels of APNB module.
  194. key_channels (int): The key channels in self-attention block.
  195. value_channels (int): The value channels in self-attention block.
  196. dropout_prob (float): The dropout rate of output.
  197. repeat_sizes (tuple, optional): The number of AFNB modules. Default: ([1]).
  198. psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
  199. """
  200. def __init__(self,
  201. in_channels,
  202. out_channels,
  203. key_channels,
  204. value_channels,
  205. dropout_prob,
  206. repeat_sizes=([1]),
  207. psp_size=(1, 3, 6, 8)):
  208. super().__init__()
  209. self.psp_size = psp_size
  210. self.stages = nn.LayerList([
  211. SelfAttentionBlock_APNB(in_channels, out_channels, key_channels,
  212. value_channels, size)
  213. for size in repeat_sizes
  214. ])
  215. self.conv_bn = layers.ConvBNReLU(
  216. in_channels=in_channels * 2,
  217. out_channels=out_channels,
  218. kernel_size=1)
  219. self.dropout = nn.Dropout(p=dropout_prob)
  220. def forward(self, x):
  221. priors = [stage(x) for stage in self.stages]
  222. context = priors[0]
  223. for i in range(1, len(priors)):
  224. context += priors[i]
  225. output = self.conv_bn(paddle.concat([context, x], axis=1))
  226. output = self.dropout(output)
  227. return output
  228. def _pp_module(x, psp_size):
  229. n, c, h, w = x.shape
  230. priors = []
  231. for size in psp_size:
  232. feat = F.adaptive_avg_pool2d(x, size)
  233. feat = paddle.reshape(feat, shape=(0, c, -1))
  234. priors.append(feat)
  235. center = paddle.concat(priors, axis=-1)
  236. return center
  237. class SelfAttentionBlock_AFNB(nn.Layer):
  238. """
  239. Self-Attention Block for AFNB module.
  240. Args:
  241. low_in_channels (int): Low-level-feature channels.
  242. high_in_channels (int): High-level-feature channels.
  243. key_channels (int): The key channels in self-attention block.
  244. value_channels (int): The value channels in self-attention block.
  245. out_channels (int, optional): Out channels of AFNB module. Default: None.
  246. scale (int, optional): Pooling size. Default: 1.
  247. psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
  248. """
  249. def __init__(self,
  250. low_in_channels,
  251. high_in_channels,
  252. key_channels,
  253. value_channels,
  254. out_channels=None,
  255. scale=1,
  256. psp_size=(1, 3, 6, 8)):
  257. super().__init__()
  258. self.scale = scale
  259. self.in_channels = low_in_channels
  260. self.out_channels = out_channels
  261. self.key_channels = key_channels
  262. self.value_channels = value_channels
  263. if out_channels == None:
  264. self.out_channels = high_in_channels
  265. self.pool = nn.MaxPool2D(scale)
  266. self.f_key = layers.ConvBNReLU(
  267. in_channels=low_in_channels,
  268. out_channels=key_channels,
  269. kernel_size=1)
  270. self.f_query = layers.ConvBNReLU(
  271. in_channels=high_in_channels,
  272. out_channels=key_channels,
  273. kernel_size=1)
  274. self.f_value = nn.Conv2D(
  275. in_channels=low_in_channels,
  276. out_channels=value_channels,
  277. kernel_size=1)
  278. self.W = nn.Conv2D(
  279. in_channels=value_channels,
  280. out_channels=out_channels,
  281. kernel_size=1)
  282. self.psp_size = psp_size
  283. def forward(self, low_feats, high_feats):
  284. batch_size, _, h, w = high_feats.shape
  285. value = self.f_value(low_feats)
  286. value = _pp_module(value, self.psp_size)
  287. value = paddle.transpose(value, (0, 2, 1))
  288. query = self.f_query(high_feats)
  289. query = paddle.reshape(query, shape=(0, self.key_channels, -1))
  290. query = paddle.transpose(query, perm=(0, 2, 1))
  291. key = self.f_key(low_feats)
  292. key = _pp_module(key, self.psp_size)
  293. sim_map = paddle.matmul(query, key)
  294. sim_map = (self.key_channels**-.5) * sim_map
  295. sim_map = F.softmax(sim_map, axis=-1)
  296. context = paddle.matmul(sim_map, value)
  297. context = paddle.transpose(context, perm=(0, 2, 1))
  298. hf_shape = paddle.shape(high_feats)
  299. context = paddle.reshape(
  300. context, shape=[0, self.value_channels, hf_shape[2], hf_shape[3]])
  301. context = self.W(context)
  302. return context
  303. class SelfAttentionBlock_APNB(nn.Layer):
  304. """
  305. Self-Attention Block for APNB module.
  306. Args:
  307. in_channels (int): The input channels of APNB module.
  308. out_channels (int): The out channels of APNB module.
  309. key_channels (int): The key channels in self-attention block.
  310. value_channels (int): The value channels in self-attention block.
  311. scale (int, optional): Pooling size. Default: 1.
  312. psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
  313. """
  314. def __init__(self,
  315. in_channels,
  316. out_channels,
  317. key_channels,
  318. value_channels,
  319. scale=1,
  320. psp_size=(1, 3, 6, 8)):
  321. super().__init__()
  322. self.scale = scale
  323. self.in_channels = in_channels
  324. self.out_channels = out_channels
  325. self.key_channels = key_channels
  326. self.value_channels = value_channels
  327. self.pool = nn.MaxPool2D(scale)
  328. self.f_key = layers.ConvBNReLU(
  329. in_channels=self.in_channels,
  330. out_channels=self.key_channels,
  331. kernel_size=1)
  332. self.f_query = self.f_key
  333. self.f_value = nn.Conv2D(
  334. in_channels=self.in_channels,
  335. out_channels=self.value_channels,
  336. kernel_size=1)
  337. self.W = nn.Conv2D(
  338. in_channels=self.value_channels,
  339. out_channels=self.out_channels,
  340. kernel_size=1)
  341. self.psp_size = psp_size
  342. def forward(self, x):
  343. batch_size, _, h, w = x.shape
  344. if self.scale > 1:
  345. x = self.pool(x)
  346. value = self.f_value(x)
  347. value = _pp_module(value, self.psp_size)
  348. value = paddle.transpose(value, perm=(0, 2, 1))
  349. query = self.f_query(x)
  350. query = paddle.reshape(query, shape=(0, self.key_channels, -1))
  351. query = paddle.transpose(query, perm=(0, 2, 1))
  352. key = self.f_key(x)
  353. key = _pp_module(key, self.psp_size)
  354. sim_map = paddle.matmul(query, key)
  355. sim_map = (self.key_channels**-.5) * sim_map
  356. sim_map = F.softmax(sim_map, axis=-1)
  357. context = paddle.matmul(sim_map, value)
  358. context = paddle.transpose(context, perm=(0, 2, 1))
  359. x_shape = paddle.shape(x)
  360. context = paddle.reshape(
  361. context, shape=[0, self.value_channels, x_shape[2], x_shape[3]])
  362. context = self.W(context)
  363. return context