decoupled_segnet.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import cv2
  15. import numpy as np
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from paddlex.paddleseg.cvlibs import manager
  20. from paddlex.paddleseg.models import layers
  21. from paddlex.paddleseg.models.backbones import resnet_vd
  22. from paddlex.paddleseg.models import deeplab
  23. from paddlex.paddleseg.utils import utils
  24. @manager.MODELS.add_component
  25. class DecoupledSegNet(nn.Layer):
  26. """
  27. The DecoupledSegNet implementation based on PaddlePaddle.
  28. The original article refers to
  29. Xiangtai Li, et, al. "Improving Semantic Segmentation via Decoupled Body and Edge Supervision"
  30. (https://arxiv.org/pdf/2007.10035.pdf)
  31. Args:
  32. num_classes (int): The unique number of target classes.
  33. backbone (paddle.nn.Layer): Backbone network, currently support Resnet50_vd/Resnet101_vd.
  34. backbone_indices (tuple, optional): Two values in the tuple indicate the indices of output of backbone.
  35. Default: (0, 3).
  36. aspp_ratios (tuple, optional): The dilation rate using in ASSP module.
  37. If output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
  38. If output_stride=8, aspp_ratios is (1, 12, 24, 36).
  39. Default: (1, 6, 12, 18).
  40. aspp_out_channels (int, optional): The output channels of ASPP module. Default: 256.
  41. align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
  42. e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
  43. pretrained (str, optional): The path or url of pretrained model. Default: None.
  44. """
  45. def __init__(self,
  46. num_classes,
  47. backbone,
  48. backbone_indices=(0, 3),
  49. aspp_ratios=(1, 6, 12, 18),
  50. aspp_out_channels=256,
  51. align_corners=False,
  52. pretrained=None):
  53. super().__init__()
  54. self.backbone = backbone
  55. backbone_channels = self.backbone.feat_channels
  56. self.head = DecoupledSegNetHead(num_classes, backbone_indices,
  57. backbone_channels, aspp_ratios,
  58. aspp_out_channels, align_corners)
  59. self.align_corners = align_corners
  60. self.pretrained = pretrained
  61. self.init_weight()
  62. def forward(self, x):
  63. feat_list = self.backbone(x)
  64. logit_list = self.head(feat_list)
  65. seg_logit, body_logit, edge_logit = [
  66. F.interpolate(
  67. logit,
  68. paddle.shape(x)[2:],
  69. mode='bilinear',
  70. align_corners=self.align_corners) for logit in logit_list
  71. ]
  72. return [seg_logit, body_logit, edge_logit, (seg_logit, edge_logit)]
  73. def init_weight(self):
  74. if self.pretrained is not None:
  75. utils.load_entire_model(self, self.pretrained)
  76. class DecoupledSegNetHead(nn.Layer):
  77. """
  78. The DecoupledSegNetHead implementation based on PaddlePaddle.
  79. Args:
  80. num_classes (int): The unique number of target classes.
  81. backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
  82. the first index will be taken as a low-level feature in Edge presevation component;
  83. the second one will be taken as input of ASPP component.
  84. backbone_channels (tuple): The channels of output of backbone.
  85. aspp_ratios (tuple): The dilation rates using in ASSP module.
  86. aspp_out_channels (int): The output channels of ASPP module.
  87. align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
  88. is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
  89. """
  90. def __init__(self, num_classes, backbone_indices, backbone_channels,
  91. aspp_ratios, aspp_out_channels, align_corners):
  92. super().__init__()
  93. self.backbone_indices = backbone_indices
  94. self.align_corners = align_corners
  95. self.aspp = layers.ASPPModule(
  96. aspp_ratios=aspp_ratios,
  97. in_channels=backbone_channels[backbone_indices[1]],
  98. out_channels=aspp_out_channels,
  99. align_corners=align_corners,
  100. image_pooling=True)
  101. self.bot_fine = nn.Conv2D(
  102. backbone_channels[backbone_indices[0]], 48, 1, bias_attr=False)
  103. # decoupled
  104. self.squeeze_body_edge = SqueezeBodyEdge(
  105. 256, align_corners=self.align_corners)
  106. self.edge_fusion = nn.Conv2D(256 + 48, 256, 1, bias_attr=False)
  107. self.sigmoid_edge = nn.Sigmoid()
  108. self.edge_out = nn.Sequential(
  109. layers.ConvBNReLU(
  110. in_channels=256,
  111. out_channels=48,
  112. kernel_size=3,
  113. bias_attr=False), nn.Conv2D(48, 1, 1, bias_attr=False))
  114. self.dsn_seg_body = nn.Sequential(
  115. layers.ConvBNReLU(
  116. in_channels=256,
  117. out_channels=256,
  118. kernel_size=3,
  119. bias_attr=False), nn.Conv2D(
  120. 256, num_classes, 1, bias_attr=False))
  121. self.final_seg = nn.Sequential(
  122. layers.ConvBNReLU(
  123. in_channels=512,
  124. out_channels=256,
  125. kernel_size=3,
  126. bias_attr=False),
  127. layers.ConvBNReLU(
  128. in_channels=256,
  129. out_channels=256,
  130. kernel_size=3,
  131. bias_attr=False),
  132. nn.Conv2D(256, num_classes, kernel_size=1, bias_attr=False))
  133. def forward(self, feat_list):
  134. fine_fea = feat_list[self.backbone_indices[0]]
  135. fine_size = paddle.shape(fine_fea)
  136. x = feat_list[self.backbone_indices[1]]
  137. aspp = self.aspp(x)
  138. # decoupled
  139. seg_body, seg_edge = self.squeeze_body_edge(aspp)
  140. # Edge presevation and edge out
  141. fine_fea = self.bot_fine(fine_fea)
  142. seg_edge = F.interpolate(
  143. seg_edge,
  144. fine_size[2:],
  145. mode='bilinear',
  146. align_corners=self.align_corners)
  147. seg_edge = self.edge_fusion(paddle.concat([seg_edge, fine_fea], axis=1))
  148. seg_edge_out = self.edge_out(seg_edge)
  149. seg_edge_out = self.sigmoid_edge(seg_edge_out) # seg_edge output
  150. seg_body_out = self.dsn_seg_body(seg_body) # body out
  151. # seg_final out
  152. seg_out = seg_edge + F.interpolate(
  153. seg_body,
  154. fine_size[2:],
  155. mode='bilinear',
  156. align_corners=self.align_corners)
  157. aspp = F.interpolate(
  158. aspp,
  159. fine_size[2:],
  160. mode='bilinear',
  161. align_corners=self.align_corners)
  162. seg_out = paddle.concat([aspp, seg_out], axis=1)
  163. seg_final_out = self.final_seg(seg_out)
  164. return [seg_final_out, seg_body_out, seg_edge_out]
  165. class SqueezeBodyEdge(nn.Layer):
  166. def __init__(self, inplane, align_corners=False):
  167. super().__init__()
  168. self.align_corners = align_corners
  169. self.down = nn.Sequential(
  170. layers.ConvBNReLU(
  171. inplane, inplane, kernel_size=3, groups=inplane, stride=2),
  172. layers.ConvBNReLU(
  173. inplane, inplane, kernel_size=3, groups=inplane, stride=2))
  174. self.flow_make = nn.Conv2D(
  175. inplane * 2, 2, kernel_size=3, padding='same', bias_attr=False)
  176. def forward(self, x):
  177. size = paddle.shape(x)[2:]
  178. seg_down = self.down(x)
  179. seg_down = F.interpolate(
  180. seg_down,
  181. size=size,
  182. mode='bilinear',
  183. align_corners=self.align_corners)
  184. flow = self.flow_make(paddle.concat([x, seg_down], axis=1))
  185. seg_flow_warp = self.flow_warp(x, flow, size)
  186. seg_edge = x - seg_flow_warp
  187. return seg_flow_warp, seg_edge
  188. def flow_warp(self, input, flow, size):
  189. input_shape = paddle.shape(input)
  190. norm = size[::-1].reshape([1, 1, 1, -1])
  191. norm.stop_gradient = True
  192. h_grid = paddle.linspace(-1.0, 1.0, size[0]).reshape([-1, 1])
  193. h_grid = h_grid.tile([size[1]])
  194. w_grid = paddle.linspace(-1.0, 1.0, size[1]).reshape([-1, 1])
  195. w_grid = w_grid.tile([size[0]]).transpose([1, 0])
  196. grid = paddle.concat([w_grid.unsqueeze(2), h_grid.unsqueeze(2)], axis=2)
  197. grid.unsqueeze(0).tile([input_shape[0], 1, 1, 1])
  198. grid = grid + paddle.transpose(flow, (0, 2, 3, 1)) / norm
  199. output = F.grid_sample(input, grid)
  200. return output