# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
from paddlex.paddleseg.cvlibs import manager
from paddlex.paddleseg.utils import load_entire_model
from paddlex.paddleseg.cvlibs.param_init import kaiming_normal_init, constant_init
from paddlex.paddleseg.models.layers.layer_libs import SyncBatchNorm


@manager.MODELS.add_component
class UNetPlusPlus(nn.Layer):
    """
    The UNet++ implementation based on PaddlePaddle.

    The original article refers to
    Zongwei Zhou, et al. "UNet++: A Nested U-Net Architecture for Medical Image Segmentation"
    (https://arxiv.org/abs/1807.10165).

    Args:
        in_channels (int): The channel number of the input image.
        num_classes (int): The number of target classes.
        use_deconv (bool, optional): Whether to use deconvolution (transposed convolution)
            for upsampling. If False, bilinear interpolation is used. Default: False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to
            False when the feature size is even, e.g. 1024x512; otherwise it should be set
            to True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of the pretrained model for fine-tuning.
            Default: None.
        is_ds (bool, optional): Whether to use deep supervision. If True, the outputs of
            the four nested decoder branches are averaged; otherwise only the deepest
            branch is used. Default: True.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 use_deconv=False,
                 align_corners=False,
                 pretrained=None,
                 is_ds=True):
        super(UNetPlusPlus, self).__init__()
        self.pretrained = pretrained
        self.is_ds = is_ds
        channels = [32, 64, 128, 256, 512]
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2)

        # Encoder nodes X{i}_0: plain double-conv blocks at five resolutions.
        self.conv0_0 = DoubleConv(in_channels, channels[0])
        self.conv1_0 = DoubleConv(channels[0], channels[1])
        self.conv2_0 = DoubleConv(channels[1], channels[2])
        self.conv3_0 = DoubleConv(channels[2], channels[3])
        self.conv4_0 = DoubleConv(channels[3], channels[4])

        # Nested decoder nodes: up_cat{i}_{j} fuses an upsampled deeper feature
        # with the j skip features at level i, hence n_cat = j + 1.
        self.up_cat0_1 = UpSampling(
            channels[1],
            channels[0],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_1 = UpSampling(
            channels[2],
            channels[1],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat2_1 = UpSampling(
            channels[3],
            channels[2],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat3_1 = UpSampling(
            channels[4],
            channels[3],
            n_cat=2,
            use_deconv=use_deconv,
            align_corners=align_corners)

        self.up_cat0_2 = UpSampling(
            channels[1],
            channels[0],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_2 = UpSampling(
            channels[2],
            channels[1],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat2_2 = UpSampling(
            channels[3],
            channels[2],
            n_cat=3,
            use_deconv=use_deconv,
            align_corners=align_corners)

        self.up_cat0_3 = UpSampling(
            channels[1],
            channels[0],
            n_cat=4,
            use_deconv=use_deconv,
            align_corners=align_corners)
        self.up_cat1_3 = UpSampling(
            channels[2],
            channels[1],
            n_cat=4,
            use_deconv=use_deconv,
            align_corners=align_corners)

        self.up_cat0_4 = UpSampling(
            channels[1],
            channels[0],
            n_cat=5,
            use_deconv=use_deconv,
            align_corners=align_corners)

        # 1x1 output heads, one per nested decoder branch (deep supervision).
        self.out_1 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_2 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_3 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)
        self.out_4 = nn.Conv2D(channels[0], num_classes, 1, 1, 0)

        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            load_entire_model(self, self.pretrained)
        else:
            for sublayer in self.sublayers():
                if isinstance(sublayer, nn.Conv2D):
                    kaiming_normal_init(sublayer.weight)
                elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                    # Norm layers follow the usual convention:
                    # scale initialized to 1, bias to 0.
                    constant_init(sublayer.weight, value=1.0)
                    constant_init(sublayer.bias, value=0.0)

    def forward(self, inputs):
        # Encoder (downsampling path)
        X0_0 = self.conv0_0(inputs)  # n,32,h,w
        pool_0 = self.pool(X0_0)  # n,32,h/2,w/2
        X1_0 = self.conv1_0(pool_0)  # n,64,h/2,w/2
        pool_1 = self.pool(X1_0)  # n,64,h/4,w/4
        X2_0 = self.conv2_0(pool_1)  # n,128,h/4,w/4
        pool_2 = self.pool(X2_0)  # n,128,h/8,w/8
        X3_0 = self.conv3_0(pool_2)  # n,256,h/8,w/8
        pool_3 = self.pool(X3_0)  # n,256,h/16,w/16
        X4_0 = self.conv4_0(pool_3)  # n,512,h/16,w/16

        # Column 1: upsample + concat
        X0_1 = self.up_cat0_1(X1_0, X0_0)  # n,32,h,w
        X1_1 = self.up_cat1_1(X2_0, X1_0)  # n,64,h/2,w/2
        X2_1 = self.up_cat2_1(X3_0, X2_0)  # n,128,h/4,w/4
        X3_1 = self.up_cat3_1(X4_0, X3_0)  # n,256,h/8,w/8

        # Column 2: upsample + concat
        X0_2 = self.up_cat0_2(X1_1, X0_0, X0_1)  # n,32,h,w
        X1_2 = self.up_cat1_2(X2_1, X1_0, X1_1)  # n,64,h/2,w/2
        X2_2 = self.up_cat2_2(X3_1, X2_0, X2_1)  # n,128,h/4,w/4

        # Column 3: upsample + concat
        X0_3 = self.up_cat0_3(X1_2, X0_0, X0_1, X0_2)  # n,32,h,w
        X1_3 = self.up_cat1_3(X2_2, X1_0, X1_1, X1_2)  # n,64,h/2,w/2

        # Column 4: upsample + concat
        X0_4 = self.up_cat0_4(X1_3, X0_0, X0_1, X0_2, X0_3)  # n,32,h,w

        # 1x1 output heads
        out_1 = self.out_1(X0_1)  # n,num_classes,h,w
        out_2 = self.out_2(X0_2)  # n,num_classes,h,w
        out_3 = self.out_3(X0_3)  # n,num_classes,h,w
        out_4 = self.out_4(X0_4)  # n,num_classes,h,w

        output = (out_1 + out_2 + out_3 + out_4) / 4
        if self.is_ds:
            return [output]
        else:
            return [out_4]
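

# Building blocks used by UNetPlusPlus. Every nested decoder node is computed as
#   X{i}_{j} = UpSampling(X{i+1}_{j-1}, X{i}_{0}, ..., X{i}_{j-1}),
# i.e. the deeper feature is upsampled by 2x and concatenated with all earlier
# features at the same resolution level (so n_cat = j + 1), then fused by a
# DoubleConv block.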
class DoubleConv(nn.Layer):
    """Two consecutive Conv-BN-ReLU units with the same output channels."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 filter_size=3,
                 stride=1,
                 padding=1):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(in_channels, out_channels, filter_size, stride, padding),
            SyncBatchNorm(out_channels),
            nn.ReLU(),
            nn.Conv2D(out_channels, out_channels, filter_size, stride,
                      padding),
            SyncBatchNorm(out_channels),
            nn.ReLU())

    def forward(self, inputs):
        return self.conv(inputs)


class UpSampling(nn.Layer):
    """Upsamples the deeper feature by 2x, concatenates it with the skip
    features, and fuses the result with a DoubleConv block."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_cat,
                 use_deconv=False,
                 align_corners=False):
        super(UpSampling, self).__init__()
        if use_deconv:
            self.up = nn.Conv2DTranspose(
                in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        else:
            self.up = nn.Sequential(
                nn.Upsample(
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=align_corners),
                nn.Conv2D(in_channels, out_channels, 1, 1, 0))
        self.conv = DoubleConv(n_cat * out_channels, out_channels)

    def forward(self, high_feature, *low_features):
        # Upsample the deeper feature, then concatenate it with every skip
        # feature along the channel axis before fusing.
        features = [self.up(high_feature)]
        for feature in low_features:
            features.append(feature)
        cat_features = paddle.concat(features, axis=1)
        return self.conv(cat_features)
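

if __name__ == '__main__':
    # Minimal smoke test, added here purely for illustration; it is not part of
    # the original module. It assumes a working paddle installation and checks
    # that a forward pass yields logits of the expected shape. The input height
    # and width must be divisible by 16 because the encoder downsamples 4 times.
    model = UNetPlusPlus(in_channels=3, num_classes=2)
    model.eval()
    x = paddle.rand([1, 3, 64, 64])
    logits = model(x)  # forward() returns a list with a single logit tensor
    print(logits[0].shape)  # expected: [1, 2, 64, 64]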
|