| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from paddlex.paddleseg.cvlibs import manager
- from paddlex.paddleseg.models import layers
- from paddlex.paddleseg.utils import utils
- __all__ = ['DeepLabV3P', 'DeepLabV3']
@manager.MODELS.add_component
class DeepLabV3P(nn.Layer):
    """
    DeepLabV3+ segmentation network implemented with PaddlePaddle.

    Reference:
        Liang-Chieh Chen, et al. "Encoder-Decoder with Atrous Separable
        Convolution for Semantic Image Segmentation"
        (https://arxiv.org/abs/1802.02611)

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network, currently support
            Resnet50_vd/Resnet101_vd/Xception65. It must expose
            ``feat_channels`` and return a list of feature maps.
        backbone_indices (tuple, optional): Two indices into the backbone
            outputs. Default: (0, 3).
        aspp_ratios (tuple, optional): The dilation rates used in the ASPP
            module. If output_stride=16, use (1, 6, 12, 18); if
            output_stride=8, use (1, 12, 24, 36). Default: (1, 6, 12, 18).
        aspp_out_channels (int, optional): The output channels of the ASPP
            module. Default: 256.
        align_corners (bool, optional): An argument of F.interpolate. It
            should be set to False when the feature size is even, e.g.
            1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of a pretrained model.
            Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(0, 3),
                 aspp_ratios=(1, 6, 12, 18),
                 aspp_out_channels=256,
                 align_corners=False,
                 pretrained=None):
        super().__init__()

        self.backbone = backbone

        # Channel count of every backbone output the head will consume,
        # in the same order as ``backbone_indices``.
        selected_channels = []
        for idx in backbone_indices:
            selected_channels.append(backbone.feat_channels[idx])

        self.head = DeepLabV3PHead(
            num_classes=num_classes,
            backbone_indices=backbone_indices,
            backbone_channels=selected_channels,
            aspp_ratios=aspp_ratios,
            aspp_out_channels=aspp_out_channels,
            align_corners=align_corners)

        self.align_corners = align_corners
        self.pretrained = pretrained
        # Must run last: loading expects backbone and head to exist already.
        self.init_weight()

    def forward(self, x):
        """
        Run the backbone and head, then resize every logit to the input size.

        Returns:
            list: One tensor per logit produced by the head, bilinearly
            interpolated to ``paddle.shape(x)[2:]``.
        """
        logits = self.head(self.backbone(x))
        target_size = paddle.shape(x)[2:]
        return [
            F.interpolate(
                logit,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners) for logit in logits
        ]

    def init_weight(self):
        """Load the entire model from ``self.pretrained`` if one was given."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
class DeepLabV3PHead(nn.Layer):
    """
    Segmentation head of DeepLabV3+: an ASPP module followed by a decoder.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Two indices into the backbone outputs.
            The first selects the low-level feature used by the decoder;
            the second selects the input of the ASPP module. A backbone
            usually has four downsampling stages and returns the output of
            each, so (0, 3) means: first-stage feature map as the low-level
            feature, fourth-stage feature map as the ASPP input.
        backbone_channels (tuple): Same length as ``backbone_indices``;
            the channel count of the feature map at each selected index.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        aspp_out_channels (int): The output channels of the ASPP module.
        align_corners (bool): An argument of F.interpolate. It should be set
            to False when the output size of the feature is even, e.g.
            1024x512, otherwise it is True, e.g. 769x769.
    """

    def __init__(self, num_classes, backbone_indices, backbone_channels,
                 aspp_ratios, aspp_out_channels, align_corners):
        super().__init__()

        aspp_in_channels = backbone_channels[1]
        low_level_channels = backbone_channels[0]

        # Sub-layers are created in a fixed order (aspp, then decoder).
        self.aspp = layers.ASPPModule(
            aspp_ratios,
            aspp_in_channels,
            aspp_out_channels,
            align_corners,
            use_sep_conv=True,
            image_pooling=True)
        self.decoder = Decoder(num_classes, low_level_channels, align_corners)

        self.backbone_indices = backbone_indices

    def forward(self, feat_list):
        """
        Produce the segmentation logit from the backbone feature list.

        Args:
            feat_list (list): Feature maps returned by the backbone.

        Returns:
            list: A single-element list holding the decoder output.
        """
        low_index = self.backbone_indices[0]
        high_index = self.backbone_indices[1]

        shallow_feat = feat_list[low_index]
        context_feat = self.aspp(feat_list[high_index])
        return [self.decoder(context_feat, shallow_feat)]
@manager.MODELS.add_component
class DeepLabV3(nn.Layer):
    """
    DeepLabV3 segmentation network implemented with PaddlePaddle.

    Reference:
        Liang-Chieh Chen, et al. "Rethinking Atrous Convolution for Semantic
        Image Segmentation" (https://arxiv.org/pdf/1706.05587.pdf).

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network. It must expose
            ``feat_channels`` and return a list of feature maps.
        backbone_indices (tuple, optional): Indices into the backbone
            outputs; the head reads the first one. Default: (3, ).
        aspp_ratios (tuple, optional): The dilation rates used in the ASPP
            module. Default: (1, 6, 12, 18).
        aspp_out_channels (int, optional): The output channels of the ASPP
            module. Default: 256.
        align_corners (bool, optional): An argument of F.interpolate.
            Default: False.
        pretrained (str, optional): The path or url of a pretrained model.
            Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(3, ),
                 aspp_ratios=(1, 6, 12, 18),
                 aspp_out_channels=256,
                 align_corners=False,
                 pretrained=None):
        super().__init__()

        self.backbone = backbone

        # Channel count of each selected backbone output, in index order.
        selected_channels = []
        for idx in backbone_indices:
            selected_channels.append(backbone.feat_channels[idx])

        self.head = DeepLabV3Head(
            num_classes=num_classes,
            backbone_indices=backbone_indices,
            backbone_channels=selected_channels,
            aspp_ratios=aspp_ratios,
            aspp_out_channels=aspp_out_channels,
            align_corners=align_corners)

        self.align_corners = align_corners
        self.pretrained = pretrained
        # Must run last: loading expects backbone and head to exist already.
        self.init_weight()

    def forward(self, x):
        """
        Run the backbone and head, then resize every logit to the input size.

        Returns:
            list: One tensor per logit produced by the head, bilinearly
            interpolated to ``paddle.shape(x)[2:]``.
        """
        logits = self.head(self.backbone(x))
        target_size = paddle.shape(x)[2:]
        return [
            F.interpolate(
                logit,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners) for logit in logits
        ]

    def init_weight(self):
        """Load the entire model from ``self.pretrained`` if one was given."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
class DeepLabV3Head(nn.Layer):
    """
    Segmentation head of DeepLabV3: an ASPP module followed by a 1x1
    classification convolution.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Indices into the backbone outputs; the
            first one selects the input of the ASPP module.
        backbone_channels (tuple): Same length as ``backbone_indices``;
            the channel count of the feature map at each selected index.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        aspp_out_channels (int): The output channels of the ASPP module.
        align_corners (bool): Forwarded to the ASPP module.
    """

    def __init__(self, num_classes, backbone_indices, backbone_channels,
                 aspp_ratios, aspp_out_channels, align_corners):
        super().__init__()

        aspp_in_channels = backbone_channels[0]

        # Sub-layers are created in a fixed order (aspp, then cls).
        self.aspp = layers.ASPPModule(
            aspp_ratios,
            aspp_in_channels,
            aspp_out_channels,
            align_corners,
            use_sep_conv=False,
            image_pooling=True)
        self.cls = nn.Conv2D(
            aspp_out_channels, num_classes, kernel_size=1)

        self.backbone_indices = backbone_indices

    def forward(self, feat_list):
        """
        Produce the segmentation logit from the backbone feature list.

        Args:
            feat_list (list): Feature maps returned by the backbone.

        Returns:
            list: A single-element list holding the classifier output.
        """
        selected_feat = feat_list[self.backbone_indices[0]]
        context_feat = self.aspp(selected_feat)
        return [self.cls(context_feat)]
class Decoder(nn.Layer):
    """
    Decoder module of DeepLabV3P model.

    The low-level feature is projected to 48 channels by a 1x1 ConvBNReLU.
    The high-level input is bilinearly resized to the low-level feature's
    size, and the two are concatenated along axis 1. Two 3x3 separable
    ConvBNReLU blocks and a final 1x1 convolution then produce the
    per-class logits.

    Args:
        num_classes (int): The number of classes.
        in_channels (int): The number of channels of the low-level feature
            passed to ``forward`` as ``low_level_feat``.
        align_corners (bool): An argument of F.interpolate, used when
            resizing the high-level input.
        aspp_out_channels (int, optional): The number of channels of the
            high-level tensor passed to ``forward`` as ``x`` (the ASPP
            output). The fusion convolution expects
            ``aspp_out_channels + 48`` input channels. Default: 256, which
            reproduces the previously hard-coded width of 304.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 align_corners,
                 aspp_out_channels=256):
        super().__init__()

        # Width of the projected low-level feature.
        low_level_channels = 48

        self.conv_bn_relu1 = layers.ConvBNReLU(
            in_channels=in_channels,
            out_channels=low_level_channels,
            kernel_size=1)
        # The concatenation in forward() stacks the ASPP output and the
        # projected low-level feature, so the fused width is their sum
        # rather than a fixed 304.
        self.conv_bn_relu2 = layers.SeparableConvBNReLU(
            in_channels=aspp_out_channels + low_level_channels,
            out_channels=256,
            kernel_size=3,
            padding=1)
        self.conv_bn_relu3 = layers.SeparableConvBNReLU(
            in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv = nn.Conv2D(
            in_channels=256, out_channels=num_classes, kernel_size=1)

        self.align_corners = align_corners

    def forward(self, x, low_level_feat):
        """
        Fuse the high-level input with the low-level feature and classify.

        Args:
            x: High-level feature with ``aspp_out_channels`` channels.
            low_level_feat: Low-level backbone feature with ``in_channels``
                channels.

        Returns:
            The output of the final 1x1 convolution (``num_classes``
            output channels), at the size of ``low_level_feat``.
        """
        low_level_feat = self.conv_bn_relu1(low_level_feat)
        # Resize x to the low-level feature's trailing dimensions so the
        # two tensors can be concatenated.
        x = F.interpolate(
            x,
            paddle.shape(low_level_feat)[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = paddle.concat([x, low_level_feat], axis=1)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.conv(x)
        return x
|