unet.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg import utils
  18. from paddlex.paddleseg.cvlibs import manager
  19. from paddlex.paddleseg.models import layers
  @manager.MODELS.add_component
  class UNet(nn.Layer):
      """
      U-Net segmentation network implemented with PaddlePaddle.

      Reference:
          Olaf Ronneberger, et al. "U-Net: Convolutional Networks for Biomedical
          Image Segmentation" (https://arxiv.org/abs/1505.04597).

      Args:
          num_classes (int): The unique number of target classes.
          align_corners (bool): Forwarded to F.interpolate in the decoder. It
              should be set to False when the output size of feature is even,
              e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
          use_deconv (bool, optional): If True, the decoder upsamples with a
              transposed convolution; if False, with bilinear resizing.
              Default: False.
          pretrained (str, optional): The path or url of pretrained model for
              fine tuning. Default: None.
      """

      def __init__(self,
                   num_classes,
                   align_corners=False,
                   use_deconv=False,
                   pretrained=None):
          super().__init__()
          self.encode = Encoder()
          self.decode = Decoder(align_corners, use_deconv=use_deconv)
          # Final 3x3 classifier over the 64-channel decoder output.
          self.cls = nn.Conv2D(
              in_channels=64,
              out_channels=num_classes,
              kernel_size=3,
              stride=1,
              padding=1)
          # The same layer object is also bound to `self.conv` (as before), so
          # it stays reachable and registered under both attribute names.
          self.conv = self.cls
          self.pretrained = pretrained
          self.init_weight()

      def forward(self, x):
          """Return a one-element list holding the class logits for `x`."""
          bottleneck, skips = self.encode(x)
          decoded = self.decode(bottleneck, skips)
          return [self.cls(decoded)]

      def init_weight(self):
          """Load the whole model from `self.pretrained` when one is given."""
          if self.pretrained is None:
              return
          utils.load_entire_model(self, self.pretrained)
  class Encoder(nn.Layer):
      """UNet contracting path.

      A stem of two ConvBNReLU layers (3 -> 64 channels) is followed by four
      stages, each a 2x2 max-pool and two ConvBNReLU layers.
      """

      def __init__(self):
          super().__init__()
          self.double_conv = nn.Sequential(
              layers.ConvBNReLU(3, 64, 3),
              layers.ConvBNReLU(64, 64, 3))
          # (in_channels, out_channels) of each downsampling stage.
          stage_channels = [[64, 128], [128, 256], [256, 512], [512, 512]]
          self.down_sample_list = nn.LayerList(
              [self.down_sampling(c_in, c_out) for c_in, c_out in stage_channels])

      def down_sampling(self, in_channels, out_channels):
          """Build one stage: MaxPool2D(2, stride 2) then two ConvBNReLU layers."""
          return nn.Sequential(
              nn.MaxPool2D(kernel_size=2, stride=2),
              layers.ConvBNReLU(in_channels, out_channels, 3),
              layers.ConvBNReLU(out_channels, out_channels, 3))

      def forward(self, x):
          """Encode `x`.

          Returns:
              tuple: (deepest feature map, list of skip features). The list holds
              the input of every downsampling stage, shallowest first.
          """
          skips = []
          feat = self.double_conv(x)
          for stage in self.down_sample_list:
              # Keep the pre-pooling feature for the decoder's skip connection.
              skips.append(feat)
              feat = stage(feat)
          return feat, skips
  class Decoder(nn.Layer):
      """UNet expanding path: a stack of UpSampling stages fed by skip features.

      Args:
          align_corners (bool): Forwarded to every UpSampling stage.
          use_deconv (bool, optional): Forwarded to every UpSampling stage.
              Default: False.
      """

      def __init__(self, align_corners, use_deconv=False):
          super().__init__()
          # (in_channels, out_channels) of each upsampling stage, deepest first.
          up_channels = [[512, 256], [256, 128], [128, 64], [64, 64]]
          self.up_sample_list = nn.LayerList([
              UpSampling(c_in, c_out, align_corners, use_deconv)
              for c_in, c_out in up_channels
          ])

      def forward(self, x, short_cuts):
          """Decode `x` using the encoder's skip features.

          Args:
              x: Deepest encoder feature map.
              short_cuts (list): Skip features ordered shallowest first (the
                  order Encoder produces them); they are consumed deepest first.

          Returns:
              The feature map produced by the last stage used.
          """
          # Skips are stored shallowest-first, so walk them in reverse: stage i
          # pairs with the i-th deepest skip. As before, supplying more skips
          # than there are stages fails on the out-of-range stage lookup.
          for i, short_cut in enumerate(reversed(short_cuts)):
              x = self.up_sample_list[i](x, short_cut)
          return x
  class UpSampling(nn.Layer):
      """One UNet decoder stage: upsample, concatenate the skip, double conv.

      Args:
          in_channels (int): Channels of the incoming (deeper) feature map. The
              channel arithmetic below assumes the skip feature also has
              `in_channels` channels, which holds for the Encoder/Decoder
              pairing in this file.
          out_channels (int): Channels produced by this stage.
          align_corners (bool): Passed to F.interpolate.
          use_deconv (bool, optional): Upsample with a stride-2 transposed
              convolution instead of bilinear resizing. Default: False.
      """

      def __init__(self,
                   in_channels,
                   out_channels,
                   align_corners,
                   use_deconv=False):
          super().__init__()
          self.align_corners = align_corners
          self.use_deconv = use_deconv
          if self.use_deconv:
              # The transposed conv emits out_channels // 2 channels, which are
              # concatenated with the in_channels-wide skip feature.
              self.deconv = nn.Conv2DTranspose(
                  in_channels,
                  out_channels // 2,
                  kernel_size=2,
                  stride=2,
                  padding=0)
              concat_channels = in_channels + out_channels // 2
          else:
              # Resizing keeps the channel count; the skip adds in_channels more.
              concat_channels = in_channels * 2
          self.double_conv = nn.Sequential(
              layers.ConvBNReLU(concat_channels, out_channels, 3),
              layers.ConvBNReLU(out_channels, out_channels, 3))

      def forward(self, x, short_cut):
          """Upsample `x` to `short_cut`'s size, concatenate along channels, refine.

          Args:
              x: Deeper feature map (channels on axis 1, spatial dims on 2:).
              short_cut: Skip feature map from the encoder, same layout.

          Returns:
              Feature map with `out_channels` channels at the skip's resolution.
          """
          if self.use_deconv:
              x = self.deconv(x)
              # A stride-2 transposed conv exactly doubles H/W, but the encoder's
              # 2x2 max-pool floors odd sizes, so for inputs whose H or W is not
              # divisible by 16 the doubled map is one pixel short of the skip
              # and the concat below would fail. Resize only on a mismatch so
              # results for evenly divisible inputs are unchanged.
              if x.shape[2:] != short_cut.shape[2:]:
                  x = self._resize_to(x, short_cut)
          else:
              x = self._resize_to(x, short_cut)
          x = paddle.concat([x, short_cut], axis=1)
          return self.double_conv(x)

      def _resize_to(self, x, ref):
          """Bilinearly resize `x` to the spatial size of `ref`."""
          return F.interpolate(
              x,
              paddle.shape(ref)[2:],
              mode='bilinear',
              align_corners=self.align_corners)