bisenet.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.paddleseg import utils
  19. from paddlex.paddleseg.cvlibs import manager, param_init
  20. from paddlex.paddleseg.models import layers
  21. @manager.MODELS.add_component
  22. class BiSeNetV2(nn.Layer):
  23. """
  24. The BiSeNet V2 implementation based on PaddlePaddle.
  25. The original article refers to
  26. Yu, Changqian, et al. "BiSeNet V2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation"
  27. (https://arxiv.org/abs/2004.02147)
  28. Args:
  29. num_classes (int): The unique number of target classes.
  30. lambd (float, optional): A factor for controlling the size of semantic branch channels. Default: 0.25.
  31. pretrained (str, optional): The path or url of pretrained model. Default: None.
  32. """
  33. def __init__(self,
  34. num_classes,
  35. lambd=0.25,
  36. align_corners=False,
  37. pretrained=None):
  38. super().__init__()
  39. C1, C2, C3 = 64, 64, 128
  40. db_channels = (C1, C2, C3)
  41. C1, C3, C4, C5 = int(C1 * lambd), int(C3 * lambd), 64, 128
  42. sb_channels = (C1, C3, C4, C5)
  43. mid_channels = 128
  44. self.db = DetailBranch(db_channels)
  45. self.sb = SemanticBranch(sb_channels)
  46. self.bga = BGA(mid_channels, align_corners)
  47. self.aux_head1 = SegHead(C1, C1, num_classes)
  48. self.aux_head2 = SegHead(C3, C3, num_classes)
  49. self.aux_head3 = SegHead(C4, C4, num_classes)
  50. self.aux_head4 = SegHead(C5, C5, num_classes)
  51. self.head = SegHead(mid_channels, mid_channels, num_classes)
  52. self.align_corners = align_corners
  53. self.pretrained = pretrained
  54. self.init_weight()
  55. def forward(self, x):
  56. dfm = self.db(x)
  57. feat1, feat2, feat3, feat4, sfm = self.sb(x)
  58. logit = self.head(self.bga(dfm, sfm))
  59. if not self.training:
  60. logit_list = [logit]
  61. else:
  62. logit1 = self.aux_head1(feat1)
  63. logit2 = self.aux_head2(feat2)
  64. logit3 = self.aux_head3(feat3)
  65. logit4 = self.aux_head4(feat4)
  66. logit_list = [logit, logit1, logit2, logit3, logit4]
  67. logit_list = [
  68. F.interpolate(
  69. logit,
  70. paddle.shape(x)[2:],
  71. mode='bilinear',
  72. align_corners=self.align_corners) for logit in logit_list
  73. ]
  74. return logit_list
  75. def init_weight(self):
  76. if self.pretrained is not None:
  77. utils.load_entire_model(self, self.pretrained)
  78. else:
  79. for sublayer in self.sublayers():
  80. if isinstance(sublayer, nn.Conv2D):
  81. param_init.kaiming_normal_init(sublayer.weight)
  82. elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
  83. param_init.constant_init(sublayer.weight, value=1.0)
  84. param_init.constant_init(sublayer.bias, value=0.0)
  85. class StemBlock(nn.Layer):
  86. def __init__(self, in_dim, out_dim):
  87. super(StemBlock, self).__init__()
  88. self.conv = layers.ConvBNReLU(in_dim, out_dim, 3, stride=2)
  89. self.left = nn.Sequential(
  90. layers.ConvBNReLU(out_dim, out_dim // 2, 1),
  91. layers.ConvBNReLU(out_dim // 2, out_dim, 3, stride=2))
  92. self.right = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
  93. self.fuse = layers.ConvBNReLU(out_dim * 2, out_dim, 3)
  94. def forward(self, x):
  95. x = self.conv(x)
  96. left = self.left(x)
  97. right = self.right(x)
  98. concat = paddle.concat([left, right], axis=1)
  99. return self.fuse(concat)
  100. class ContextEmbeddingBlock(nn.Layer):
  101. def __init__(self, in_dim, out_dim):
  102. super(ContextEmbeddingBlock, self).__init__()
  103. self.gap = nn.AdaptiveAvgPool2D(1)
  104. self.bn = layers.SyncBatchNorm(in_dim)
  105. self.conv_1x1 = layers.ConvBNReLU(in_dim, out_dim, 1)
  106. self.conv_3x3 = nn.Conv2D(out_dim, out_dim, 3, 1, 1)
  107. def forward(self, x):
  108. gap = self.gap(x)
  109. bn = self.bn(gap)
  110. conv1 = self.conv_1x1(bn) + x
  111. return self.conv_3x3(conv1)
  112. class GatherAndExpansionLayer1(nn.Layer):
  113. """Gather And Expansion Layer with stride 1"""
  114. def __init__(self, in_dim, out_dim, expand):
  115. super().__init__()
  116. expand_dim = expand * in_dim
  117. self.conv = nn.Sequential(
  118. layers.ConvBNReLU(in_dim, in_dim, 3),
  119. layers.DepthwiseConvBN(in_dim, expand_dim, 3),
  120. layers.ConvBN(expand_dim, out_dim, 1))
  121. def forward(self, x):
  122. return F.relu(self.conv(x) + x)
  123. class GatherAndExpansionLayer2(nn.Layer):
  124. """Gather And Expansion Layer with stride 2"""
  125. def __init__(self, in_dim, out_dim, expand):
  126. super().__init__()
  127. expand_dim = expand * in_dim
  128. self.branch_1 = nn.Sequential(
  129. layers.ConvBNReLU(in_dim, in_dim, 3),
  130. layers.DepthwiseConvBN(in_dim, expand_dim, 3, stride=2),
  131. layers.DepthwiseConvBN(expand_dim, expand_dim, 3),
  132. layers.ConvBN(expand_dim, out_dim, 1))
  133. self.branch_2 = nn.Sequential(
  134. layers.DepthwiseConvBN(in_dim, in_dim, 3, stride=2),
  135. layers.ConvBN(in_dim, out_dim, 1))
  136. def forward(self, x):
  137. return F.relu(self.branch_1(x) + self.branch_2(x))
  138. class DetailBranch(nn.Layer):
  139. """The detail branch of BiSeNet, which has wide channels but shallow layers."""
  140. def __init__(self, in_channels):
  141. super().__init__()
  142. C1, C2, C3 = in_channels
  143. self.convs = nn.Sequential(
  144. # stage 1
  145. layers.ConvBNReLU(3, C1, 3, stride=2),
  146. layers.ConvBNReLU(C1, C1, 3),
  147. # stage 2
  148. layers.ConvBNReLU(C1, C2, 3, stride=2),
  149. layers.ConvBNReLU(C2, C2, 3),
  150. layers.ConvBNReLU(C2, C2, 3),
  151. # stage 3
  152. layers.ConvBNReLU(C2, C3, 3, stride=2),
  153. layers.ConvBNReLU(C3, C3, 3),
  154. layers.ConvBNReLU(C3, C3, 3),
  155. )
  156. def forward(self, x):
  157. return self.convs(x)
  158. class SemanticBranch(nn.Layer):
  159. """The semantic branch of BiSeNet, which has narrow channels but deep layers."""
  160. def __init__(self, in_channels):
  161. super().__init__()
  162. C1, C3, C4, C5 = in_channels
  163. self.stem = StemBlock(3, C1)
  164. self.stage3 = nn.Sequential(
  165. GatherAndExpansionLayer2(C1, C3, 6),
  166. GatherAndExpansionLayer1(C3, C3, 6))
  167. self.stage4 = nn.Sequential(
  168. GatherAndExpansionLayer2(C3, C4, 6),
  169. GatherAndExpansionLayer1(C4, C4, 6))
  170. self.stage5_4 = nn.Sequential(
  171. GatherAndExpansionLayer2(C4, C5, 6),
  172. GatherAndExpansionLayer1(C5, C5, 6),
  173. GatherAndExpansionLayer1(C5, C5, 6),
  174. GatherAndExpansionLayer1(C5, C5, 6))
  175. self.ce = ContextEmbeddingBlock(C5, C5)
  176. def forward(self, x):
  177. stage2 = self.stem(x)
  178. stage3 = self.stage3(stage2)
  179. stage4 = self.stage4(stage3)
  180. stage5_4 = self.stage5_4(stage4)
  181. fm = self.ce(stage5_4)
  182. return stage2, stage3, stage4, stage5_4, fm
  183. class BGA(nn.Layer):
  184. """The Bilateral Guided Aggregation Layer, used to fuse the semantic features and spatial features."""
  185. def __init__(self, out_dim, align_corners):
  186. super().__init__()
  187. self.align_corners = align_corners
  188. self.db_branch_keep = nn.Sequential(
  189. layers.DepthwiseConvBN(out_dim, out_dim, 3),
  190. nn.Conv2D(out_dim, out_dim, 1))
  191. self.db_branch_down = nn.Sequential(
  192. layers.ConvBN(out_dim, out_dim, 3, stride=2),
  193. nn.AvgPool2D(kernel_size=3, stride=2, padding=1))
  194. self.sb_branch_keep = nn.Sequential(
  195. layers.DepthwiseConvBN(out_dim, out_dim, 3),
  196. nn.Conv2D(out_dim, out_dim, 1), layers.Activation(act='sigmoid'))
  197. self.sb_branch_up = layers.ConvBN(out_dim, out_dim, 3)
  198. self.conv = layers.ConvBN(out_dim, out_dim, 3)
  199. def forward(self, dfm, sfm):
  200. db_feat_keep = self.db_branch_keep(dfm)
  201. db_feat_down = self.db_branch_down(dfm)
  202. sb_feat_keep = self.sb_branch_keep(sfm)
  203. sb_feat_up = self.sb_branch_up(sfm)
  204. sb_feat_up = F.interpolate(
  205. sb_feat_up,
  206. paddle.shape(db_feat_keep)[2:],
  207. mode='bilinear',
  208. align_corners=self.align_corners)
  209. sb_feat_up = F.sigmoid(sb_feat_up)
  210. db_feat = db_feat_keep * sb_feat_up
  211. sb_feat = db_feat_down * sb_feat_keep
  212. sb_feat = F.interpolate(
  213. sb_feat,
  214. paddle.shape(db_feat)[2:],
  215. mode='bilinear',
  216. align_corners=self.align_corners)
  217. return self.conv(db_feat + sb_feat)
  218. class SegHead(nn.Layer):
  219. def __init__(self, in_dim, mid_dim, num_classes):
  220. super().__init__()
  221. self.conv_3x3 = nn.Sequential(
  222. layers.ConvBNReLU(in_dim, mid_dim, 3), nn.Dropout(0.1))
  223. self.conv_1x1 = nn.Conv2D(mid_dim, num_classes, 1, 1)
  224. def forward(self, x):
  225. conv1 = self.conv_3x3(x)
  226. conv2 = self.conv_1x1(conv1)
  227. return conv2