unet_plusplus.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from paddlex.paddleseg.cvlibs import manager
from paddlex.paddleseg.utils import load_entire_model
from paddlex.paddleseg.cvlibs.param_init import kaiming_normal_init
from paddlex.paddleseg.models.layers.layer_libs import SyncBatchNorm


@manager.MODELS.add_component
class UNetPlusPlus(nn.Layer):
    """
    The UNet++ implementation based on PaddlePaddle.

    The original article refers to
    Zongwei Zhou, et al. "UNet++: A Nested U-Net Architecture for Medical Image Segmentation"
    (https://arxiv.org/abs/1807.10165).

    Args:
        in_channels (int): The channel number of the input image.
        num_classes (int): The number of target classes.
        use_deconv (bool, optional): Whether to use deconvolution (transposed convolution)
            for upsampling. If False, use bilinear resizing. Default: False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to
            False when the output size of the feature is even, e.g. 1024x512, otherwise it
            should be True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of a pretrained model for fine tuning.
            Default: None.
        is_ds (bool, optional): Whether to use deep supervision, i.e. return the average of
            the four side outputs instead of only the last one. Default: True.

    A minimal usage sketch is given at the end of this file.
    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 use_deconv=False,
                 align_corners=False,
                 pretrained=None,
                 is_ds=True):
        super(UNetPlusPlus, self).__init__()
        self.pretrained = pretrained
        self.is_ds = is_ds
        channels = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2)

        # Encoder (backbone) nodes X{i}_0.
        self.conv0_0 = DoubleConv(in_channels, channels[0])
        self.conv1_0 = DoubleConv(channels[0], channels[1])
        self.conv2_0 = DoubleConv(channels[1], channels[2])
        self.conv3_0 = DoubleConv(channels[2], channels[3])
        self.conv4_0 = DoubleConv(channels[3], channels[4])

        # Nested decoder nodes X{i}_{j}: up_cat{i}_{j} upsamples X{i+1}_{j-1} and
        # concatenates it with X{i}_0 ... X{i}_{j-1} (n_cat feature maps in total).
        self.up_cat0_1 = UpSampling(
            channels[1],
            channels[0],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_1 = UpSampling(
            channels[2],
            channels[1],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat2_1 = UpSampling(
            channels[3],
            channels[2],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat3_1 = UpSampling(
            channels[4],
            channels[3],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat0_2 = UpSampling(
            channels[1],
            channels[0],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_2 = UpSampling(
            channels[2],
            channels[1],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat2_2 = UpSampling(
            channels[3],
            channels[2],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat0_3 = UpSampling(
            channels[1],
            channels[0],
            n_cat=4,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_3 = UpSampling(
            channels[2],
            channels[1],
            n_cat=4,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat0_4 = UpSampling(
            channels[1],
            channels[0],
            n_cat=5,
            use_deconv=use_deconv,
            align_corners=align_corners)

        # 1x1 convolutions producing the four side outputs (deep supervision heads).
        self.out_1 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_2 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_3 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_4 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)

        self.init_weight()
    def init_weight(self):
        if self.pretrained is not None:
            load_entire_model(self, self.pretrained)
        else:
            for sublayer in self.sublayers():
                if isinstance(sublayer, nn.Conv2D):
                    kaiming_normal_init(sublayer.weight)
                elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                    kaiming_normal_init(sublayer.weight)
    def forward(self, inputs):
        # Column 0: encoder / downsampling path.
        X0_0 = self.conv0_0(inputs)  # n,32,h,w
        pool_0 = self.pool(X0_0)  # n,32,h/2,w/2
        X1_0 = self.conv1_0(pool_0)  # n,64,h/2,w/2
        pool_1 = self.pool(X1_0)  # n,64,h/4,w/4
        X2_0 = self.conv2_0(pool_1)  # n,128,h/4,w/4
        pool_2 = self.pool(X2_0)  # n,128,h/8,w/8
        X3_0 = self.conv3_0(pool_2)  # n,256,h/8,w/8
        pool_3 = self.pool(X3_0)  # n,256,h/16,w/16
        X4_0 = self.conv4_0(pool_3)  # n,512,h/16,w/16

        # Column 1: upsample + concat.
        X0_1 = self.up_cat0_1(X1_0, X0_0)  # n,32,h,w
        X1_1 = self.up_cat1_1(X2_0, X1_0)  # n,64,h/2,w/2
        X2_1 = self.up_cat2_1(X3_0, X2_0)  # n,128,h/4,w/4
        X3_1 = self.up_cat3_1(X4_0, X3_0)  # n,256,h/8,w/8

        # Column 2: upsample + concat.
        X0_2 = self.up_cat0_2(X1_1, X0_0, X0_1)  # n,32,h,w
        X1_2 = self.up_cat1_2(X2_1, X1_0, X1_1)  # n,64,h/2,w/2
        X2_2 = self.up_cat2_2(X3_1, X2_0, X2_1)  # n,128,h/4,w/4

        # Column 3: upsample + concat.
        X0_3 = self.up_cat0_3(X1_2, X0_0, X0_1, X0_2)  # n,32,h,w
        X1_3 = self.up_cat1_3(X2_2, X1_0, X1_1, X1_2)  # n,64,h/2,w/2

        # Column 4: upsample + concat.
        X0_4 = self.up_cat0_4(X1_3, X0_0, X0_1, X0_2, X0_3)  # n,32,h,w

        # Side outputs via 1x1 convolutions.
        out_1 = self.out_1(X0_1)  # n,num_classes,h,w
        out_2 = self.out_2(X0_2)  # n,num_classes,h,w
        out_3 = self.out_3(X0_3)  # n,num_classes,h,w
        out_4 = self.out_4(X0_4)  # n,num_classes,h,w

        # Deep supervision: average the four side outputs; otherwise keep only the last.
        output = (out_1 + out_2 + out_3 + out_4) / 4
        if self.is_ds:
            return [output]
        else:
            return [out_4]

class DoubleConv(nn.Layer):
    """Two consecutive Conv-BN-ReLU blocks."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 filter_size=3,
                 stride=1,
                 padding=1):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels, out_channels, filter_size, stride, padding),
            SyncBatchNorm(out_channels),
            nn.ReLU(),
            nn.Conv2D(out_channels, out_channels, filter_size, stride, padding),
            SyncBatchNorm(out_channels),
            nn.ReLU())

    def forward(self, inputs):
        return self.conv(inputs)

class UpSampling(nn.Layer):
    """Upsample the deeper feature map, concatenate it with the shallower ones,
    then fuse the result with a DoubleConv block."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_cat,
                 use_deconv=False,
                 align_corners=False):
        super(UpSampling, self).__init__()
        if use_deconv:
            self.up = nn.Conv2DTranspose(
                in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        else:
            self.up = nn.Sequential(
                nn.Upsample(
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=align_corners),
                nn.Conv2D(in_channels, out_channels, 1, 1, 0))
        self.conv = DoubleConv(n_cat * out_channels, out_channels)

    def forward(self, high_feature, *low_features):
        features = [self.up(high_feature)]
        for feature in low_features:
            features.append(feature)
        cat_features = paddle.concat(features, axis=1)
        out = self.conv(cat_features)
        return out
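

# A minimal, hypothetical usage sketch (not part of the PaddleX training pipeline):
# it only instantiates the model and runs a forward pass on a random tensor, so the
# batch size, image size, and class count below are illustrative assumptions.
if __name__ == "__main__":
    model = UNetPlusPlus(in_channels=3, num_classes=2)
    dummy = paddle.rand([1, 3, 256, 256])  # n, c, h, w (h and w divisible by 16)
    logits = model(dummy)
    # With is_ds=True (the default), forward returns a single-element list holding the
    # averaged deep-supervision output of shape [1, num_classes, 256, 256].
    print(len(logits), logits[0].shape)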