Ver Fonte

compatible with PaddleX1.x model initiation apis

will-jl944 há 4 anos atrás
pai
commit
a1f4cc6fa5
3 ficheiros alterados com 587 adições e 33 exclusões
  1. 207 23
      dygraph/paddlex/cls.py
  2. 170 5
      dygraph/paddlex/det.py
  3. 210 5
      dygraph/paddlex/seg.py

+ 207 - 23
dygraph/paddlex/cls.py

@@ -14,29 +14,213 @@
 
 from . import cv
 from paddlex.cv.transforms import cls_transforms
+import paddlex.utils.logging as logging
 
 transforms = cls_transforms
 
-ResNet18 = cv.models.ResNet18
-ResNet34 = cv.models.ResNet34
-ResNet50 = cv.models.ResNet50
-ResNet101 = cv.models.ResNet101
-ResNet50_vd = cv.models.ResNet50_vd
-ResNet101_vd = cv.models.ResNet101_vd
-ResNet50_vd_ssld = cv.models.ResNet50_vd_ssld
-ResNet101_vd_ssld = cv.models.ResNet101_vd_ssld
-DarkNet53 = cv.models.DarkNet53
-MobileNetV1 = cv.models.MobileNetV1
-MobileNetV2 = cv.models.MobileNetV2
-MobileNetV3_small = cv.models.MobileNetV3_small
-MobileNetV3_large = cv.models.MobileNetV3_large
-MobileNetV3_small_ssld = cv.models.MobileNetV3_small_ssld
-MobileNetV3_large_ssld = cv.models.MobileNetV3_large_ssld
-Xception41 = cv.models.Xception41
-Xception65 = cv.models.Xception65
-DenseNet121 = cv.models.DenseNet121
-DenseNet161 = cv.models.DenseNet161
-DenseNet201 = cv.models.DenseNet201
-ShuffleNetV2 = cv.models.ShuffleNetV2
-HRNet_W18 = cv.models.HRNet_W18_C
-AlexNet = cv.models.AlexNet
+
+class ResNet18(cv.models.ResNet18):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet18, self).__init__(num_classes=num_classes)
+
+
+class ResNet34(cv.models.ResNet34):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet34, self).__init__(num_classes=num_classes)
+
+
+class ResNet50(cv.models.ResNet50):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet50, self).__init__(num_classes=num_classes)
+
+
+class ResNet101(cv.models.ResNet101):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet101, self).__init__(num_classes=num_classes)
+
+
+class ResNet50_vd(cv.models.ResNet50_vd):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet50_vd, self).__init__(num_classes=num_classes)
+
+
+class ResNet101_vd(cv.models.ResNet101_vd):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet101_vd, self).__init__(num_classes=num_classes)
+
+
+class ResNet50_vd_ssld(cv.models.ResNet50_vd_ssld):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet50_vd_ssld, self).__init__(num_classes=num_classes)
+
+
+class ResNet101_vd_ssld(cv.models.ResNet101_vd_ssld):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ResNet101_vd_ssld, self).__init__(num_classes=num_classes)
+
+
+class DarkNet53(cv.models.DarkNet53):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(DarkNet53, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV1(cv.models.MobileNetV1):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV1, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV2(cv.models.MobileNetV2):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV2, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV3_small(cv.models.MobileNetV3_small):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV3_small, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV3_large(cv.models.MobileNetV3_large):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV3_large, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV3_small_ssld(cv.models.MobileNetV3_small_ssld):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV3_small_ssld, self).__init__(num_classes=num_classes)
+
+
+class MobileNetV3_large_ssld(cv.models.MobileNetV3_large_ssld):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(MobileNetV3_large_ssld, self).__init__(num_classes=num_classes)
+
+
+class Xception41(cv.models.Xception41):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(Xception41, self).__init__(num_classes=num_classes)
+
+
+class Xception65(cv.models.Xception65):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(Xception65, self).__init__(num_classes=num_classes)
+
+
+class DenseNet121(cv.models.DenseNet121):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(DenseNet121, self).__init__(num_classes=num_classes)
+
+
+class DenseNet161(cv.models.DenseNet161):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(DenseNet161, self).__init__(num_classes=num_classes)
+
+
+class DenseNet201(cv.models.DenseNet201):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(DenseNet201, self).__init__(num_classes=num_classes)
+
+
+class ShuffleNetV2(cv.models.ShuffleNetV2):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(ShuffleNetV2, self).__init__(num_classes=num_classes)
+
+
+class HRNet_W18(cv.models.HRNet_W18_C):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(HRNet_W18, self).__init__(num_classes=num_classes)
+
+
+class AlexNet(cv.models.AlexNet):
+    def __init__(self, num_classes=1000, input_channel=None):
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(AlexNet, self).__init__(num_classes=num_classes)

+ 170 - 5
dygraph/paddlex/det.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 from . import cv
 from .cv.models.utils.visualize import visualize_detection, draw_pr_curve
@@ -18,10 +19,174 @@ from paddlex.cv.transforms import det_transforms
 
 transforms = det_transforms
 
-FasterRCNN = cv.models.FasterRCNN
-YOLOv3 = cv.models.YOLOv3
-PPYOLO = cv.models.PPYOLO
-MaskRCNN = cv.models.MaskRCNN
-
 visualize = visualize_detection
 draw_pr_curve = draw_pr_curve
+
+
+class FasterRCNN(cv.models.FasterRCNN):
+    def __init__(self,
+                 num_classes=81,
+                 backbone='ResNet50',
+                 with_fpn=True,
+                 aspect_ratios=[0.5, 1.0, 2.0],
+                 anchor_sizes=[32, 64, 128, 256, 512],
+                 with_dcn=None,
+                 rpn_cls_loss=None,
+                 rpn_focal_loss_alpha=None,
+                 rpn_focal_loss_gamma=None,
+                 rcnn_bbox_loss=None,
+                 rcnn_nms=None,
+                 keep_top_k=100,
+                 nms_threshold=0.5,
+                 score_threshold=0.05,
+                 softnms_sigma=None,
+                 bbox_assigner=None,
+                 fpn_num_channels=256,
+                 input_channel=None,
+                 rpn_batch_size_per_im=256,
+                 rpn_fg_fraction=0.5,
+                 test_pre_nms_top_n=None,
+                 test_post_nms_top_n=1000):
+        if with_dcn is not None:
+            logging.warning(
+                "`with_dcn` is deprecated in PaddleX 2.0 and won't take effect. Defaults to False."
+            )
+        if rpn_cls_loss is not None:
+            logging.warning(
+                "`rpn_cls_loss` is deprecated in PaddleX 2.0 and won't take effect. "
+                "Defaults to 'SigmoidCrossEntropy'.")
+        if rpn_focal_loss_alpha is not None or rpn_focal_loss_gamma is not None:
+            logging.warning(
+                "Focal loss is deprecated in PaddleX 2.0."
+                " `rpn_focal_loss_alpha` and `rpn_focal_loss_gamma` won't take effect."
+            )
+        if rcnn_bbox_loss is not None:
+            logging.warning(
+                "`rcnn_bbox_loss` is deprecated in PaddleX 2.0 and won't take effect. "
+                "Defaults to 'SmoothL1Loss'")
+        if rcnn_nms is not None:
+            logging.warning(
+                "MultiClassSoftNMS is deprecated in PaddleX 2.0. "
+                "`rcnn_nms` and `softnms_sigma` won't take effect. MultiClassNMS will be used by default"
+            )
+        if bbox_assigner is not None:
+            logging.warning(
+                "`bbox_assigner` is deprecated in PaddleX 2.0 and won't take effect. "
+                "Defaults to 'BBoxAssigner'")
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(FasterRCNN, self).__init__(
+            num_classes=num_classes - 1,
+            backbone=backbone,
+            with_fpn=with_fpn,
+            aspect_ratios=aspect_ratios,
+            anchor_sizes=anchor_sizes,
+            keep_top_k=keep_top_k,
+            nms_threshold=nms_threshold,
+            score_threshold=score_threshold,
+            fpn_num_channels=fpn_num_channels,
+            rpn_batch_size_per_im=rpn_batch_size_per_im,
+            rpn_fg_fraction=rpn_fg_fraction,
+            test_pre_nms_top_n=test_pre_nms_top_n,
+            test_post_nms_top_n=test_post_nms_top_n)
+
+
+class YOLOv3(cv.models.YOLOv3):
+    def __init__(self,
+                 num_classes=80,
+                 backbone='MobileNetV1',
+                 anchors=None,
+                 anchor_masks=None,
+                 ignore_threshold=0.7,
+                 nms_score_threshold=0.01,
+                 nms_topk=1000,
+                 nms_keep_topk=100,
+                 nms_iou_threshold=0.45,
+                 label_smooth=False,
+                 train_random_shapes=None,
+                 input_channel=None):
+        if train_random_shapes is not None:
+            logging.warning(
+                "`train_random_shapes` is deprecated in PaddleX 2.0 and won't take effect. "
+                "To apply multi_scale training, please refer to paddlex.transforms.BatchRandomResize: "
+                "'https://github.com/PaddlePaddle/PaddleX/blob/develop/dygraph/paddlex/cv/transforms/batch_operators.py#L53'"
+            )
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(YOLOv3, self).__init__(
+            num_classes=num_classes,
+            backbone=backbone,
+            anchors=anchors,
+            anchor_masks=anchor_masks,
+            ignore_threshold=ignore_threshold,
+            nms_score_threshold=nms_score_threshold,
+            nms_topk=nms_topk,
+            nms_keep_topk=nms_keep_topk,
+            nms_iou_threshold=nms_iou_threshold,
+            label_smooth=label_smooth)
+
+
+class PPYOLO(cv.models.PPYOLO):
+    def __init__(
+            self,
+            num_classes=80,
+            backbone='ResNet50_vd_ssld',
+            with_dcn_v2=None,
+            # YOLO Head
+            anchors=None,
+            anchor_masks=None,
+            use_coord_conv=True,
+            use_iou_aware=True,
+            use_spp=True,
+            use_drop_block=True,
+            scale_x_y=1.05,
+            # PPYOLO Loss
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_iou_loss=True,
+            # NMS
+            use_matrix_nms=True,
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            train_random_shapes=None,
+            input_channel=None):
+        if with_dcn_v2 is not None:
+            logging.warning(
+                "`with_dcn_v2` is deprecated in PaddleX 2.0 and will not take effect. "
+                "To use backbone with deformable convolutional networks, "
+                "please specify in `backbone_name`. "
+                "Currently the only backbone with dcn is 'ResNet50_vd_dcn'.")
+        if train_random_shapes is not None:
+            logging.warning(
+                "`train_random_shapes` is deprecated in PaddleX 2.0 and won't take effect. "
+                "To apply multi_scale training, please refer to paddlex.transforms.BatchRandomResize: "
+                "'https://github.com/PaddlePaddle/PaddleX/blob/develop/dygraph/paddlex/cv/transforms/batch_operators.py#L53'"
+            )
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+        super(PPYOLO, self).__init__(
+            num_classes=num_classes,
+            backbone=backbone,
+            anchors=anchors,
+            anchor_masks=anchor_masks,
+            use_coord_conv=use_coord_conv,
+            use_iou_aware=use_iou_aware,
+            use_spp=use_spp,
+            use_drop_block=use_drop_block,
+            scale_x_y=scale_x_y,
+            ignore_threshold=ignore_threshold,
+            label_smooth=label_smooth,
+            use_iou_loss=use_iou_loss,
+            use_matrix_nms=use_matrix_nms,
+            nms_score_threshold=nms_score_threshold,
+            nms_topk=nms_topk,
+            nms_keep_topk=nms_keep_topk,
+            nms_iou_threshold=nms_iou_threshold)

+ 210 - 5
dygraph/paddlex/seg.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 from . import cv
 from .cv.models.utils.visualize import visualize_segmentation
@@ -18,9 +19,213 @@ from paddlex.cv.transforms import seg_transforms
 
 transforms = seg_transforms
 
-UNet = cv.models.UNet
-DeepLabv3p = cv.models.DeepLabV3P
-HRNet = cv.models.HRNet
-FastSCNN = cv.models.FastSCNN
-
 visualize = visualize_segmentation
+
+
+class UNet(cv.models.UNet):
+    def __init__(self,
+                 num_classes=2,
+                 upsample_mode='bilinear',
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=None,
+                 input_channel=None):
+        if num_classes > 2 and (use_bce_loss or use_dice_loss):
+            raise ValueError(
+                "dice loss and bce loss is only applicable to binary classfication"
+            )
+        elif num_classes == 2:
+            if use_bce_loss and use_dice_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
+            elif use_bce_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1)]
+            elif use_dice_loss:
+                use_mixed_loss = [('DiceLoss', 1)]
+            else:
+                use_mixed_loss = False
+        else:
+            use_mixed_loss = False
+
+        if class_weight is not None:
+            logging.warning(
+                "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
+            )
+        if ignore_index is not None:
+            logging.warning(
+                "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
+            )
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+
+        if upsample_mode == 'bilinear':
+            use_deconv = False
+        else:
+            use_deconv = True
+        super(UNet, self).__init__(
+            num_classes=num_classes,
+            use_mixed_loss=use_mixed_loss,
+            use_deconv=use_deconv)
+
+
+class DeepLabv3p(cv.models.DeepLabV3P):
+    def __init__(self,
+                 num_classes=2,
+                 backbone='ResNet50_vd',
+                 output_stride=8,
+                 aspp_with_sep_conv=None,
+                 decoder_use_sep_conv=None,
+                 encoder_with_aspp=None,
+                 enable_decoder=None,
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=None,
+                 pooling_crop_size=None,
+                 input_channel=None):
+        if num_classes > 2 and (use_bce_loss or use_dice_loss):
+            raise ValueError(
+                "dice loss and bce loss is only applicable to binary classfication"
+            )
+        elif num_classes == 2:
+            if use_bce_loss and use_dice_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
+            elif use_bce_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1)]
+            elif use_dice_loss:
+                use_mixed_loss = [('DiceLoss', 1)]
+            else:
+                use_mixed_loss = False
+        else:
+            use_mixed_loss = False
+
+        if aspp_with_sep_conv is not None:
+            logging.warning(
+                "`aspp_with_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
+                "Defaults to True")
+        if decoder_use_sep_conv is not None:
+            logging.warning(
+                "`decoder_use_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
+                "Defaults to True")
+        if encoder_with_aspp is not None:
+            logging.warning(
+                "`encoder_with_aspp` is deprecated in PaddleX 2.0 and will not take effect. "
+                "Defaults to True")
+        if enable_decoder is not None:
+            logging.warning(
+                "`enable_decoder` is deprecated in PaddleX 2.0 and will not take effect. "
+                "Defaults to True")
+        if class_weight is not None:
+            logging.warning(
+                "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
+            )
+        if ignore_index is not None:
+            logging.warning(
+                "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
+            )
+        if pooling_crop_size is not None:
+            logging.warning(
+                "Backbone 'MobileNetV3_large_x1_0_ssld' is currently not supported in PaddleX 2.0. "
+                "`pooling_crop_size` will not take effect. Defaults to None")
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+
+        super(DeepLabv3p, self).__init__(
+            num_classes=num_classes,
+            backbone=backbone,
+            use_mixed_loss=use_mixed_loss,
+            output_stride=output_stride)
+
+
+class HRNet(cv.models.HRNet):
+    def __init__(self,
+                 num_classes=2,
+                 width=18,
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=None,
+                 input_channel=None):
+        if num_classes > 2 and (use_bce_loss or use_dice_loss):
+            raise ValueError(
+                "dice loss and bce loss is only applicable to binary classfication"
+            )
+        elif num_classes == 2:
+            if use_bce_loss and use_dice_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
+            elif use_bce_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1)]
+            elif use_dice_loss:
+                use_mixed_loss = [('DiceLoss', 1)]
+            else:
+                use_mixed_loss = False
+        else:
+            use_mixed_loss = False
+
+        if class_weight is not None:
+            logging.warning(
+                "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
+            )
+        if ignore_index is not None:
+            logging.warning(
+                "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
+            )
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+
+        super(HRNet, self).__init__(
+            num_classes=num_classes,
+            width=width,
+            use_mixed_loss=use_mixed_loss)
+
+
+class FastSCNN(cv.models.FastSCNN):
+    def __init__(self,
+                 num_classes=2,
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=255,
+                 multi_loss_weight=None,
+                 input_channel=3):
+        if num_classes > 2 and (use_bce_loss or use_dice_loss):
+            raise ValueError(
+                "dice loss and bce loss is only applicable to binary classfication"
+            )
+        elif num_classes == 2:
+            if use_bce_loss and use_dice_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
+            elif use_bce_loss:
+                use_mixed_loss = [('CrossEntropyLoss', 1)]
+            elif use_dice_loss:
+                use_mixed_loss = [('DiceLoss', 1)]
+            else:
+                use_mixed_loss = False
+        else:
+            use_mixed_loss = False
+
+        if class_weight is not None:
+            logging.warning(
+                "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
+            )
+        if ignore_index is not None:
+            logging.warning(
+                "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
+            )
+        if multi_loss_weight is not None:
+            logging.warning(
+                "`multi_loss_weight` is deprecated in PaddleX 2.0 and will not take effect. "
+                "Defaults to [1.0, 0.4]")
+        if input_channel is not None:
+            logging.warning(
+                "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
+            )
+
+        super(FastSCNN, self).__init__(
+            num_classes=num_classes, use_mixed_loss=use_mixed_loss)