Browse Source

Merge remote-tracking branch 'paddle/develop' into develop

syyxsxx 5 years ago
parent
commit
2c94ac53a5

+ 2 - 2
README.md

@@ -104,8 +104,8 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
 ## 交流与反馈
 
 - 项目官网:https://www.paddlepaddle.org.cn/paddle/paddlex
-- PaddleX用户交流群:1045148026 (手机QQ扫描如下二维码快速加入)  
-  ![](./docs/gui/images/QR.jpg)
+- PaddleX用户交流群:957286141 (手机QQ扫描如下二维码快速加入)  
+  ![](./docs/gui/images/QR2.jpg)
 
 
 

+ 21 - 11
deploy/lite/android/sdk/src/main/java/com/baidu/paddlex/preprocess/Transforms.java

@@ -23,6 +23,7 @@ import org.opencv.core.Scalar;
 import org.opencv.core.Size;
 import org.opencv.imgproc.Imgproc;
 import java.util.ArrayList;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
 
@@ -101,6 +102,15 @@ public class Transforms {
                 if (info.containsKey("coarsest_stride")) {
                     padding.coarsest_stride = (int) info.get("coarsest_stride");
                 }
+                if (info.containsKey("im_padding_value")) {
+                    List<Double> im_padding_value = (List<Double>) info.get("im_padding_value");
+                    if (im_padding_value.size()!=3){
+                        Log.e(TAG, "len of im_padding_value in padding must == 3.");
+                    }
+                    for (int k =0; i<im_padding_value.size(); i++){
+                        padding.paddding_value[k] = im_padding_value.get(k);
+                    }
+                }
                 if (info.containsKey("target_size")) {
                     if (info.get("target_size") instanceof Integer) {
                         padding.width = (int) info.get("target_size");
@@ -124,7 +134,7 @@ public class Transforms {
         if(transformsMode.equalsIgnoreCase("RGB")){
             Imgproc.cvtColor(inputMat, inputMat, Imgproc.COLOR_BGR2RGB);
         }else if(!transformsMode.equalsIgnoreCase("BGR")){
-            Log.e(TAG, "transformsMode only support RGB or BGR");
+            Log.e(TAG, "transformsMode only support RGB or BGR.");
         }
         inputMat.convertTo(inputMat, CvType.CV_32FC(3));
 
@@ -136,16 +146,15 @@ public class Transforms {
         int h = inputMat.height();
         int c = inputMat.channels();
         imageBlob.setImageData(new float[w * h * c]);
-        int[] channelStride = new int[]{w * h, w * h * 2};
-        for (int y = 0; y < h; y++) {
-            for (int x = 0;
-                 x < w; x++) {
-                double[] color = inputMat.get(y, x);
-                imageBlob.getImageData()[y * w + x]  =  (float) (color[0]);
-                imageBlob.getImageData()[y * w + x +  channelStride[0]] = (float) (color[1]);
-                imageBlob.getImageData()[y * w + x +  channelStride[1]] = (float) (color[2]);
-            }
+
+        Mat singleChannelMat = new Mat(h, w, CvType.CV_32FC(1));
+        float[] singleChannelImageData = new float[w * h];
+        for (int i = 0; i < c; i++) {
+            Core.extractChannel(inputMat, singleChannelMat, i);
+            singleChannelMat.get(0, 0, singleChannelImageData);
+            System.arraycopy(singleChannelImageData ,0, imageBlob.getImageData(),i*w*h, w*h);
         }
+
         return imageBlob;
     }
 
@@ -248,6 +257,7 @@ public class Transforms {
         private double width;
         private double height;
         private double coarsest_stride;
+        private double[] paddding_value = {0.0, 0.0, 0.0};
 
         public Mat run(Mat inputMat, ImageBlob imageBlob) {
             int origin_w = inputMat.width();
@@ -264,7 +274,7 @@ public class Transforms {
             }
             imageBlob.setNewImageSize(inputMat.height(),2);
             imageBlob.setNewImageSize(inputMat.width(),3);
-            Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(0));
+            Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(paddding_value));
             return inputMat;
         }
     }

+ 6 - 5
deploy/lite/android/sdk/src/main/java/com/baidu/paddlex/visual/Visualize.java

@@ -31,8 +31,11 @@ import org.opencv.core.Scalar;
 import org.opencv.core.Size;
 import org.opencv.imgproc.Imgproc;
 
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Date;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
@@ -120,13 +123,11 @@ public class Visualize {
         int new_w = (int)imageBlob.getNewImageSize()[3];
         Mat mask = new Mat(new_h, new_w, CvType.CV_32FC(1));
         float[] scoreData = new float[new_h*new_w];
-        for  (int h = 0; h < new_h; h++) {
-            for  (int w = 0; w < new_w; w++){
-                scoreData[new_h * h + w] =  (1-result.getMask().getScoreData()[cutoutClass + h * new_h + w]) * 255;
-            }
-        }
+        System.arraycopy(result.getMask().getScoreData() ,cutoutClass*new_h*new_w, scoreData ,0, new_h*new_w);
         mask.put(0,0, scoreData);
+        Core.multiply(mask, new Scalar(255), mask);
         mask.convertTo(mask,CvType.CV_8UC(1));
+
         ListIterator<Map.Entry<String, int[]>> reverseReshapeInfo = new ArrayList<Map.Entry<String, int[]>>(imageBlob.getReshapeInfo().entrySet()).listIterator(imageBlob.getReshapeInfo().size());
         while (reverseReshapeInfo.hasPrevious()) {
             Map.Entry<String, int[]> entry = reverseReshapeInfo.previous();

+ 1 - 1
docs/README.md

@@ -1,6 +1,6 @@
 # PaddleX文档
 
-PaddleX的使用文档均在本目录结构下。文档采用Read the Docs方式组织,您可以直接访问[在线文档](https://paddlex.readthedocs.io/zh_CN/latest/index.html)进行查阅。
+PaddleX的使用文档均在本目录结构下。文档采用Read the Docs方式组织,您可以直接访问[在线文档](https://paddlex.readthedocs.io/zh_CN/develop/index.html)进行查阅。
 
 ## 编译文档
 在本目录下按如下步骤进行文档编译

+ 2 - 1
docs/apis/deploy.md

@@ -7,7 +7,7 @@
 图像分类、目标检测、实例分割、语义分割统一的预测器,实现高性能预测。
 
 ```
-paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_trt=False, use_glog=False, memory_optimize=True)
+paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, mkl_thread_num=4, use_trt=False, use_glog=False, memory_optimize=True)
 ```
 
 **参数**
@@ -16,6 +16,7 @@ paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_
 > * **use_gpu** (bool): 是否使用GPU进行预测。
 > * **gpu_id** (int): 使用的GPU序列号。
 > * **use_mkl** (bool): 是否使用mkldnn加速库。
+> * **mkl_thread_num** (int): 使用mkldnn加速库时的线程数,默认为4
 > * **use_trt** (boll): 是否使用TensorRT预测引擎。
 > * **use_glog** (bool): 是否打印中间日志。
 > * **memory_optimize** (bool): 是否优化内存使用。

+ 3 - 2
docs/apis/models/semantic_segmentation.md

@@ -3,7 +3,7 @@
 ## paddlex.seg.DeepLabv3p
 
 ```python
-paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
+paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None)
 
 ```
 
@@ -12,7 +12,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > **参数**
 
 > > - **num_classes** (int): 类别数。
-> > - **backbone** (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0'],默认值为'MobileNetV2_x1.0'。
+> > - **backbone** (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'],默认值为'MobileNetV2_x1.0'。
 > > - **output_stride** (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。
 > > - **aspp_with_sep_conv** (bool):  decoder模块是否采用separable convolutions。默认True。
 > > - **decoder_use_sep_conv** (bool): decoder模块是否采用separable convolutions。默认True。
@@ -22,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用,当`use_bce_loss`和`use_dice_loss`都为False时,使用交叉熵损失函数。默认False。
 > > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
 > > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
+> > - **pooling_crop_size** (int):当backbone为`MobileNetV3_large_x1_0_ssld`时,需设置为训练过程中模型输入大小,格式为[W, H]。例如模型输入大小为[512, 512], 则`pooling_crop_size`应该设置为[512, 512]。在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用`avg_pool`算子得到平均值。默认值None。
 
 ### train
 

+ 1 - 0
docs/appendix/model_zoo.md

@@ -81,6 +81,7 @@
 
 | 模型    | 模型大小    | 预测时间(毫秒) | mIoU(%) |
 |:-------|:-----------|:-------------|:----------|
+| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz) | 9.3MB | - | 73.28 |
 | [DeepLabv3_MobileNetv2_x1.0](https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz) | 14.7MB | - | 69.8 |
 | [DeepLabv3_Xception65](https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz) | 329.3MB | - | 79.3 |
 | [HRNet_W18](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 77.3MB |  | 79.36 |

+ 1 - 1
docs/deploy/server/python.md

@@ -27,7 +27,7 @@ import paddlex as pdx
 predictor = pdx.deploy.Predictor('./inference_model')
 image_list = ['xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg',
     'xiaoduxiong_test_image/JPEGImages/WeChatIMG111.jpeg']
-result = predictor.predict(image_list=image_list)
+result = predictor.batch_predict(image_list=image_list)
 ```
 
 * 视频流预测

+ 1 - 0
docs/examples/solutions.md

@@ -80,6 +80,7 @@ PaddleX目前提供了实例分割MaskRCNN模型,支持5种不同的backbone
 | 模型 | 模型特点 | 存储体积 | GPU预测速度 | CPU(x86)预测速度(毫秒) | 骁龙855(ARM)预测速度 (毫秒)| mIoU |
 | :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
 | DeepLabv3p-MobileNetV2_x1.0 | 轻量级模型,适用于移动端场景| - | - | - | 69.8% |
+| DeepLabv3-MobileNetV3_large_x1_0_ssld | 轻量级模型,适用于移动端场景| - | - | - | 73.28% |
 | HRNet_W18_Small_v1 | 轻量高速,适用于移动端场景 | - | - | - | - |
 | FastSCNN | 轻量高速,适用于追求高速预测的移动端或服务器端场景 | - | - | - | 69.64 |
 | HRNet_W18 | 高精度模型,适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景| - | - | - | 79.36 |

BIN
docs/gui/images/QR2.jpg


+ 1 - 0
docs/train/semantic_segmentation.md

@@ -12,6 +12,7 @@ PaddleX目前提供了DeepLabv3p、UNet、HRNet和FastSCNN四种语义分割结
 | :----------------  | :------- | :------- | :---------  | :---------  | :-----    |
 | [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) |  -  |  2.9MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
 | [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) |  69.8%  |  11MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
+| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py) | 73.28% | 9.3MB |  -   | -  |  模型小,预测速度快,精度较高,适用于低性能或移动端设备 |
 | [DeepLabv3p-Xception65](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_xception65.py)        | 79.3%  | 158MB   |  -  | -  |  模型大,精度高,适用于服务端   |
 | [UNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/unet.py)     | -  | 52MB   | -   | -  |  模型较大,精度高,适用于服务端   |
 | [HRNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/hrnet.py)   |  79.4%   |   37MB    |  -       |   -    | 模型较小,模型精度高,适用于服务端部署   |

+ 7 - 1
paddlex/command.py

@@ -52,6 +52,12 @@ def arg_parser():
         default=False,
         help="export onnx model for deployment")
     parser.add_argument(
+        "--onnx_opset",
+        "-oo",
+        type=int,
+        default=10,
+        help="when use paddle2onnx, set onnx opset version to export")
+    parser.add_argument(
         "--data_conversion",
         "-dc",
         action="store_true",
@@ -162,7 +168,7 @@ def main():
             logging.error(
                 "paddlex --export_inference --model_dir model_path --save_dir infer_model"
             )
-        pdx.convertor.export_onnx_model(model, args.save_dir)
+        pdx.convertor.export_onnx_model(model, args.save_dir, args.onnx_opset)
 
     if args.data_conversion:
         assert args.source is not None, "--source should be defined while converting dataset"

+ 407 - 5
paddlex/convertor.py

@@ -29,10 +29,12 @@ def export_onnx(model_dir, save_dir, fixed_input_shape):
     export_onnx_model(model, save_dir)
 
 
-def export_onnx_model(model, save_dir):
-    if model.model_type == "detector" or model.__class__.__name__ == "FastSCNN":
+def export_onnx_model(model, save_dir, opset_version=10):
+    if model.__class__.__name__ == "FastSCNN" or (
+            model.model_type == "detector" and
+            model.__class__.__name__ != "YOLOv3"):
         logging.error(
-            "Only image classifier models and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
+            "Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
         )
     try:
         import x2paddle
@@ -41,6 +43,406 @@ def export_onnx_model(model, save_dir):
     except:
         logging.error(
             "You need to install x2paddle first, pip install x2paddle>=0.7.4")
-    from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper
+    if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
+        logging.warning(
+            "Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use X2Paddle to export"
+        )
+        x2paddle.op_mapper.paddle2onnx.opset10.paddle_custom_layer.multiclass_nms.multiclass_nms = multiclass_nms_for_openvino
+    from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
     mapper = PaddleOpMapper()
-    mapper.convert(model.test_prog, save_dir)
+    mapper.convert(
+        model.test_prog,
+        save_dir,
+        scope=model.scope,
+        opset_version=opset_version)
+
+
+def multiclass_nms_for_openvino(op, block):
+    """
+    Convert the paddle multiclass_nms to onnx op.
+    This op is get the select boxes from origin boxes.
+    This op is for OpenVINO, which donn't support dynamic shape).
+    """
+    import math
+    import sys
+    import numpy as np
+    import paddle.fluid.core as core
+    import paddle.fluid as fluid
+    import onnx
+    import warnings
+    from onnx import helper, onnx_pb
+    inputs = dict()
+    outputs = dict()
+    attrs = dict()
+    for name in op.input_names:
+        inputs[name] = op.input(name)
+    for name in op.output_names:
+        outputs[name] = op.output(name)
+    for name in op.attr_names:
+        attrs[name] = op.attr(name)
+
+    result_name = outputs['Out'][0]
+    background = attrs['background_label']
+    normalized = attrs['normalized']
+    if normalized == False:
+        warnings.warn(
+            'The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX. \
+                         Please set normalized=True in multiclass_nms of Paddle'
+        )
+
+    #convert the paddle attribute to onnx tensor
+    name_score_threshold = [outputs['Out'][0] + "@score_threshold"]
+    name_iou_threshold = [outputs['Out'][0] + "@iou_threshold"]
+    name_keep_top_k = [outputs['Out'][0] + '@keep_top_k']
+    name_keep_top_k_2D = [outputs['Out'][0] + '@keep_top_k_1D']
+
+    node_score_threshold = onnx.helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=name_score_threshold,
+        value=onnx.helper.make_tensor(
+            name=name_score_threshold[0] + "@const",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(),
+            vals=[float(attrs['score_threshold'])]))
+
+    node_iou_threshold = onnx.helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=name_iou_threshold,
+        value=onnx.helper.make_tensor(
+            name=name_iou_threshold[0] + "@const",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(),
+            vals=[float(attrs['nms_threshold'])]))
+
+    node_keep_top_k = onnx.helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=name_keep_top_k,
+        value=onnx.helper.make_tensor(
+            name=name_keep_top_k[0] + "@const",
+            data_type=onnx.TensorProto.INT64,
+            dims=(),
+            vals=[np.int64(attrs['keep_top_k'])]))
+
+    node_keep_top_k_2D = onnx.helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=name_keep_top_k_2D,
+        value=onnx.helper.make_tensor(
+            name=name_keep_top_k_2D[0] + "@const",
+            data_type=onnx.TensorProto.INT64,
+            dims=[1, 1],
+            vals=[np.int64(attrs['keep_top_k'])]))
+
+    # the paddle data format is x1,y1,x2,y2
+    kwargs = {'center_point_box': 0}
+
+    name_select_nms = [outputs['Out'][0] + "@select_index"]
+    node_select_nms= onnx.helper.make_node(
+        'NonMaxSuppression',
+        inputs=inputs['BBoxes'] + inputs['Scores'] + name_keep_top_k +\
+            name_iou_threshold + name_score_threshold,
+        outputs=name_select_nms)
+    # step 1 nodes select the nms class
+    node_list = [
+        node_score_threshold, node_iou_threshold, node_keep_top_k,
+        node_keep_top_k_2D, node_select_nms
+    ]
+
+    # create some const value to use
+    name_const_value = [result_name+"@const_0",
+        result_name+"@const_1",\
+        result_name+"@const_2",\
+        result_name+"@const_-1"]
+    value_const_value = [0, 1, 2, -1]
+    for name, value in zip(name_const_value, value_const_value):
+        node = onnx.helper.make_node(
+            'Constant',
+            inputs=[],
+            outputs=[name],
+            value=onnx.helper.make_tensor(
+                name=name + "@const",
+                data_type=onnx.TensorProto.INT64,
+                dims=[1],
+                vals=[value]))
+        node_list.append(node)
+
+    # In this code block, we will deocde the raw score data, reshape N * C * M to 1 * N*C*M
+    # and the same time, decode the select indices to 1 * D, gather the select_indices
+    outputs_gather_1_ = [result_name + "@gather_1_"]
+    node_gather_1_ = onnx.helper.make_node(
+        'Gather',
+        inputs=name_select_nms + [result_name + "@const_1"],
+        outputs=outputs_gather_1_,
+        axis=1)
+    node_list.append(node_gather_1_)
+    outputs_gather_1 = [result_name + "@gather_1"]
+    node_gather_1 = onnx.helper.make_node(
+        'Unsqueeze',
+        inputs=outputs_gather_1_,
+        outputs=outputs_gather_1,
+        axes=[0])
+    node_list.append(node_gather_1)
+
+    outputs_gather_2_ = [result_name + "@gather_2_"]
+    node_gather_2_ = onnx.helper.make_node(
+        'Gather',
+        inputs=name_select_nms + [result_name + "@const_2"],
+        outputs=outputs_gather_2_,
+        axis=1)
+    node_list.append(node_gather_2_)
+
+    outputs_gather_2 = [result_name + "@gather_2"]
+    node_gather_2 = onnx.helper.make_node(
+        'Unsqueeze',
+        inputs=outputs_gather_2_,
+        outputs=outputs_gather_2,
+        axes=[0])
+    node_list.append(node_gather_2)
+
+    # reshape scores N * C * M to (N*C*M) * 1
+    outputs_reshape_scores_rank1 = [result_name + "@reshape_scores_rank1"]
+    node_reshape_scores_rank1 = onnx.helper.make_node(
+        "Reshape",
+        inputs=inputs['Scores'] + [result_name + "@const_-1"],
+        outputs=outputs_reshape_scores_rank1)
+    node_list.append(node_reshape_scores_rank1)
+
+    # get the shape of scores
+    outputs_shape_scores = [result_name + "@shape_scores"]
+    node_shape_scores = onnx.helper.make_node(
+        'Shape', inputs=inputs['Scores'], outputs=outputs_shape_scores)
+    node_list.append(node_shape_scores)
+
+    # gather the index: 2 shape of scores
+    outputs_gather_scores_dim1 = [result_name + "@gather_scores_dim1"]
+    node_gather_scores_dim1 = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_shape_scores + [result_name + "@const_2"],
+        outputs=outputs_gather_scores_dim1,
+        axis=0)
+    node_list.append(node_gather_scores_dim1)
+
+    # mul class * M
+    outputs_mul_classnum_boxnum = [result_name + "@mul_classnum_boxnum"]
+    node_mul_classnum_boxnum = onnx.helper.make_node(
+        'Mul',
+        inputs=outputs_gather_1 + outputs_gather_scores_dim1,
+        outputs=outputs_mul_classnum_boxnum)
+    node_list.append(node_mul_classnum_boxnum)
+
+    # add class * M * index
+    outputs_add_class_M_index = [result_name + "@add_class_M_index"]
+    node_add_class_M_index = onnx.helper.make_node(
+        'Add',
+        inputs=outputs_mul_classnum_boxnum + outputs_gather_2,
+        outputs=outputs_add_class_M_index)
+    node_list.append(node_add_class_M_index)
+
+    # Squeeze the indices to 1 dim
+    outputs_squeeze_select_index = [result_name + "@squeeze_select_index"]
+    node_squeeze_select_index = onnx.helper.make_node(
+        'Squeeze',
+        inputs=outputs_add_class_M_index,
+        outputs=outputs_squeeze_select_index,
+        axes=[0, 2])
+    node_list.append(node_squeeze_select_index)
+
+    # gather the data from flatten scores
+    outputs_gather_select_scores = [result_name + "@gather_select_scores"]
+    node_gather_select_scores = onnx.helper.make_node('Gather',
+        inputs=outputs_reshape_scores_rank1 + \
+            outputs_squeeze_select_index,
+        outputs=outputs_gather_select_scores,
+        axis=0)
+    node_list.append(node_gather_select_scores)
+
+    # get nums to input TopK
+    outputs_shape_select_num = [result_name + "@shape_select_num"]
+    node_shape_select_num = onnx.helper.make_node(
+        'Shape',
+        inputs=outputs_gather_select_scores,
+        outputs=outputs_shape_select_num)
+    node_list.append(node_shape_select_num)
+
+    outputs_gather_select_num = [result_name + "@gather_select_num"]
+    node_gather_select_num = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_shape_select_num + [result_name + "@const_0"],
+        outputs=outputs_gather_select_num,
+        axis=0)
+    node_list.append(node_gather_select_num)
+
+    outputs_unsqueeze_select_num = [result_name + "@unsqueeze_select_num"]
+    node_unsqueeze_select_num = onnx.helper.make_node(
+        'Unsqueeze',
+        inputs=outputs_gather_select_num,
+        outputs=outputs_unsqueeze_select_num,
+        axes=[0])
+    node_list.append(node_unsqueeze_select_num)
+
+    outputs_concat_topK_select_num = [result_name + "@conat_topK_select_num"]
+    node_conat_topK_select_num = onnx.helper.make_node(
+        'Concat',
+        inputs=outputs_unsqueeze_select_num + name_keep_top_k_2D,
+        outputs=outputs_concat_topK_select_num,
+        axis=0)
+    node_list.append(node_conat_topK_select_num)
+
+    outputs_cast_concat_topK_select_num = [
+        result_name + "@concat_topK_select_num"
+    ]
+    node_outputs_cast_concat_topK_select_num = onnx.helper.make_node(
+        'Cast',
+        inputs=outputs_concat_topK_select_num,
+        outputs=outputs_cast_concat_topK_select_num,
+        to=6)
+    node_list.append(node_outputs_cast_concat_topK_select_num)
+    # get min(topK, num_select)
+    outputs_compare_topk_num_select = [
+        result_name + "@compare_topk_num_select"
+    ]
+    node_compare_topk_num_select = onnx.helper.make_node(
+        'ReduceMin',
+        inputs=outputs_cast_concat_topK_select_num,
+        outputs=outputs_compare_topk_num_select,
+        keepdims=0)
+    node_list.append(node_compare_topk_num_select)
+
+    # unsqueeze the indices to 1D tensor
+    outputs_unsqueeze_topk_select_indices = [
+        result_name + "@unsqueeze_topk_select_indices"
+    ]
+    node_unsqueeze_topk_select_indices = onnx.helper.make_node(
+        'Unsqueeze',
+        inputs=outputs_compare_topk_num_select,
+        outputs=outputs_unsqueeze_topk_select_indices,
+        axes=[0])
+    node_list.append(node_unsqueeze_topk_select_indices)
+
+    # cast the indices to INT64
+    outputs_cast_topk_indices = [result_name + "@cast_topk_indices"]
+    node_cast_topk_indices = onnx.helper.make_node(
+        'Cast',
+        inputs=outputs_unsqueeze_topk_select_indices,
+        outputs=outputs_cast_topk_indices,
+        to=7)
+    node_list.append(node_cast_topk_indices)
+
+    # select topk scores  indices
+    outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
+        result_name + "@topk_select_topk_indices"]
+    node_topk_select_topk_indices = onnx.helper.make_node(
+        'TopK',
+        inputs=outputs_gather_select_scores + outputs_cast_topk_indices,
+        outputs=outputs_topk_select_topk_indices)
+    node_list.append(node_topk_select_topk_indices)
+
+    # gather topk label, scores, boxes
+    outputs_gather_topk_scores = [result_name + "@gather_topk_scores"]
+    node_gather_topk_scores = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_gather_select_scores +
+        [outputs_topk_select_topk_indices[1]],
+        outputs=outputs_gather_topk_scores,
+        axis=0)
+    node_list.append(node_gather_topk_scores)
+
+    outputs_gather_topk_class = [result_name + "@gather_topk_class"]
+    node_gather_topk_class = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_gather_1 + [outputs_topk_select_topk_indices[1]],
+        outputs=outputs_gather_topk_class,
+        axis=1)
+    node_list.append(node_gather_topk_class)
+
+    # gather the boxes need to gather the boxes id, then get boxes
+    outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"]
+    node_gather_topk_boxes_id = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_gather_2 + [outputs_topk_select_topk_indices[1]],
+        outputs=outputs_gather_topk_boxes_id,
+        axis=1)
+    node_list.append(node_gather_topk_boxes_id)
+
+    # squeeze the gather_topk_boxes_id to 1 dim
+    outputs_squeeze_topk_boxes_id = [result_name + "@squeeze_topk_boxes_id"]
+    node_squeeze_topk_boxes_id = onnx.helper.make_node(
+        'Squeeze',
+        inputs=outputs_gather_topk_boxes_id,
+        outputs=outputs_squeeze_topk_boxes_id,
+        axes=[0, 2])
+    node_list.append(node_squeeze_topk_boxes_id)
+
+    outputs_gather_select_boxes = [result_name + "@gather_select_boxes"]
+    node_gather_select_boxes = onnx.helper.make_node(
+        'Gather',
+        inputs=inputs['BBoxes'] + outputs_squeeze_topk_boxes_id,
+        outputs=outputs_gather_select_boxes,
+        axis=1)
+    node_list.append(node_gather_select_boxes)
+
+    # concat the final result
+    # before concat need to cast the class to float
+    outputs_cast_topk_class = [result_name + "@cast_topk_class"]
+    node_cast_topk_class = onnx.helper.make_node(
+        'Cast',
+        inputs=outputs_gather_topk_class,
+        outputs=outputs_cast_topk_class,
+        to=1)
+    node_list.append(node_cast_topk_class)
+
+    outputs_unsqueeze_topk_scores = [result_name + "@unsqueeze_topk_scores"]
+    node_unsqueeze_topk_scores = onnx.helper.make_node(
+        'Unsqueeze',
+        inputs=outputs_gather_topk_scores,
+        outputs=outputs_unsqueeze_topk_scores,
+        axes=[0, 2])
+    node_list.append(node_unsqueeze_topk_scores)
+
+    inputs_concat_final_results = outputs_cast_topk_class + outputs_unsqueeze_topk_scores +\
+        outputs_gather_select_boxes
+    outputs_sort_by_socre_results = [result_name + "@concat_topk_scores"]
+    node_sort_by_socre_results = onnx.helper.make_node(
+        'Concat',
+        inputs=inputs_concat_final_results,
+        outputs=outputs_sort_by_socre_results,
+        axis=2)
+    node_list.append(node_sort_by_socre_results)
+
+    # select topk classes indices
+    outputs_squeeze_cast_topk_class = [
+        result_name + "@squeeze_cast_topk_class"
+    ]
+    node_squeeze_cast_topk_class = onnx.helper.make_node(
+        'Squeeze',
+        inputs=outputs_cast_topk_class,
+        outputs=outputs_squeeze_cast_topk_class,
+        axes=[0, 2])
+    node_list.append(node_squeeze_cast_topk_class)
+    outputs_neg_squeeze_cast_topk_class = [
+        result_name + "@neg_squeeze_cast_topk_class"
+    ]
+    node_neg_squeeze_cast_topk_class = onnx.helper.make_node(
+        'Neg',
+        inputs=outputs_squeeze_cast_topk_class,
+        outputs=outputs_neg_squeeze_cast_topk_class)
+    node_list.append(node_neg_squeeze_cast_topk_class)
+    outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
+        result_name + "@topk_select_topk_classes_indices"]
+    node_topk_select_topk_indices = onnx.helper.make_node(
+        'TopK',
+        inputs=outputs_neg_squeeze_cast_topk_class + outputs_cast_topk_indices,
+        outputs=outputs_topk_select_classes_indices)
+    node_list.append(node_topk_select_topk_indices)
+    outputs_concat_final_results = outputs['Out']
+    node_concat_final_results = onnx.helper.make_node(
+        'Gather',
+        inputs=outputs_sort_by_socre_results +
+        [outputs_topk_select_classes_indices[1]],
+        outputs=outputs_concat_final_results,
+        axis=1)
+    node_list.append(node_concat_final_results)
+    return node_list

+ 56 - 12
paddlex/cv/models/deeplabv3p.py

@@ -37,7 +37,7 @@ class DeepLabv3p(BaseAPI):
         num_classes (int): 类别数。
         backbone (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41',
             'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
-            'MobileNetV2_x2.0']。默认'MobileNetV2_x1.0'。
+            'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']。默认'MobileNetV2_x1.0'。
         output_stride (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。
         aspp_with_sep_conv (bool):  在asspp模块是否采用separable convolutions。默认True。
         decoder_use_sep_conv (bool): decoder模块是否采用separable convolutions。默认True。
@@ -51,10 +51,13 @@ class DeepLabv3p(BaseAPI):
             自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1,
             即平时使用的交叉熵损失函数。
         ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
+        pooling_crop_size (list): 当backbone为MobileNetV3_large_x1_0_ssld时,需设置为训练过程中模型输入大小, 格式为[W, H]。
+            在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用'pool'算子得到平均值。
+            默认值为None。
     Raises:
         ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
         ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
-            'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']之内。
+            'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']之内。
         ValueError: class_weight为list, 但长度不等于num_class。
                 class_weight为str, 但class_weight.low()不等于dynamic。
         TypeError: class_weight不为None时,其类型不是list或str。
@@ -71,7 +74,8 @@ class DeepLabv3p(BaseAPI):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 pooling_crop_size=None):
         self.init_params = locals()
         super(DeepLabv3p, self).__init__('segmenter')
         # dice_loss或bce_loss只适用两类分割中
@@ -85,12 +89,12 @@ class DeepLabv3p(BaseAPI):
         if backbone not in [
                 'Xception65', 'Xception41', 'MobileNetV2_x0.25',
                 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
-                'MobileNetV2_x2.0'
+                'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'
         ]:
             raise ValueError(
                 "backbone: {} is set wrong. it should be one of "
                 "('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
-                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')".
+                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld')".
                 format(backbone))
 
         if class_weight is not None:
@@ -121,6 +125,30 @@ class DeepLabv3p(BaseAPI):
         self.labels = None
         self.sync_bn = True
         self.fixed_input_shape = None
+        self.pooling_stride = [1, 1]
+        self.pooling_crop_size = pooling_crop_size
+        self.aspp_with_se = False
+        self.se_use_qsigmoid = False
+        self.aspp_convs_filters = 256
+        self.aspp_with_concat_projection = True
+        self.add_image_level_feature = True
+        self.use_sum_merge = False
+        self.conv_filters = 256
+        self.output_is_logits = False
+        self.backbone_lr_mult_list = None
+        if 'MobileNetV3' in backbone:
+            self.output_stride = 32
+            self.pooling_stride = (4, 5)
+            self.aspp_with_se = True
+            self.se_use_qsigmoid = True
+            self.aspp_convs_filters = 128
+            self.aspp_with_concat_projection = False
+            self.add_image_level_feature = False
+            self.use_sum_merge = True
+            self.output_is_logits = True
+            if self.output_is_logits:
+                self.conv_filters = self.num_classes
+            self.backbone_lr_mult_list = [0.15, 0.35, 0.65, 0.85, 1]
 
     def _get_backbone(self, backbone):
         def mobilenetv2(backbone):
@@ -167,10 +195,22 @@ class DeepLabv3p(BaseAPI):
                 end_points=end_points,
                 decode_points=decode_points)
 
+        def mobilenetv3(backbone):
+            scale = 1.0
+            lr_mult_list = self.backbone_lr_mult_list
+            return paddlex.cv.nets.MobileNetV3(
+                scale=scale,
+                model_name='large',
+                output_stride=self.output_stride,
+                lr_mult_list=lr_mult_list,
+                for_seg=True)
+
         if 'Xception' in backbone:
             return xception(backbone)
         elif 'MobileNetV2' in backbone:
             return mobilenetv2(backbone)
+        elif 'MobileNetV3' in backbone:
+            return mobilenetv3(backbone)
 
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.DeepLabv3p(
@@ -186,7 +226,17 @@ class DeepLabv3p(BaseAPI):
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
             ignore_index=self.ignore_index,
-            fixed_input_shape=self.fixed_input_shape)
+            fixed_input_shape=self.fixed_input_shape,
+            pooling_stride=self.pooling_stride,
+            pooling_crop_size=self.pooling_crop_size,
+            aspp_with_se=self.aspp_with_se,
+            se_use_qsigmoid=self.se_use_qsigmoid,
+            aspp_convs_filters=self.aspp_convs_filters,
+            aspp_with_concat_projection=self.aspp_with_concat_projection,
+            add_image_level_feature=self.add_image_level_feature,
+            use_sum_merge=self.use_sum_merge,
+            conv_filters=self.conv_filters,
+            output_is_logits=self.output_is_logits)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()
@@ -370,9 +420,6 @@ class DeepLabv3p(BaseAPI):
                     elif info[0] == 'padding':
                         w, h = info[1][1], info[1][0]
                         one_pred = one_pred[0:h, 0:w]
-                    else:
-                        raise Exception(
-                            "Unexpected info '{}' in im_info".format(info[0]))
                 one_pred = one_pred.astype('int64')
                 one_pred = one_pred[np.newaxis, :, :, np.newaxis]
                 one_label = one_label[np.newaxis, np.newaxis, :, :]
@@ -430,9 +477,6 @@ class DeepLabv3p(BaseAPI):
                     w, h = info[1][1], info[1][0]
                     pred = pred[0:h, 0:w]
                     logit = logit[0:h, 0:w, :]
-                else:
-                    raise Exception("Unexpected info '{}' in im_info".format(
-                        info[0]))
             pred_list.append(pred)
             logit_list.append(logit)
 

+ 26 - 0
paddlex/cv/models/slim/prune_config.py

@@ -243,6 +243,32 @@ def get_prune_params(model):
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
+    
+    elif model_type.startswith('HRNet'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'conv-1_weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
+    
+    elif model_type.startswith('FastSCNN'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'classifier/weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
 
     elif model_type.startswith('DeepLabv3p'):
         for param in program.global_block().all_parameters():

+ 4 - 1
paddlex/cv/models/utils/pretrain_weights.py

@@ -122,6 +122,8 @@ coco_pretrain = {
 }
 
 cityscapes_pretrain = {
+    'DeepLabv3p_MobileNetV3_large_x1_0_ssld_CITYSCAPES':
+    'https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz',
     'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES':
     'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz',
     'DeepLabv3p_Xception65_CITYSCAPES':
@@ -144,7 +146,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
     if flag == 'COCO':
         if class_name == 'DeepLabv3p' and backbone in [
                 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
-                'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
+                'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
+                'MobileNetV3_large_x1_0_ssld'
         ]:
             model_name = '{}_{}'.format(class_name, backbone)
             logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))

+ 1 - 1
paddlex/cv/nets/detection/yolo_v3.py

@@ -311,7 +311,7 @@ class YOLOv3:
 
     def _upsample(self, input, scale=2, name=None):
         out = fluid.layers.resize_nearest(
-            input=input, scale=float(scale), name=name)
+            input=input, scale=float(scale), name=name, align_corners=False)
         return out
 
     def _detection_block(self,

+ 5 - 2
paddlex/cv/nets/hrnet.py

@@ -235,10 +235,13 @@ class HRNet(object):
                         name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
                     if self.feature_maps == "stage4":
                         y = fluid.layers.resize_bilinear(
-                            input=y, out_shape=[height, width])
+                            input=y,
+                            out_shape=[height, width],
+                            align_corners=False,
+                            align_mode=1)
                     else:
                         y = fluid.layers.resize_nearest(
-                            input=y, scale=2**(j - i))
+                            input=y, scale=2**(j - i), align_corners=False)
                     residual = fluid.layers.elementwise_add(
                         x=residual, y=y, act=None)
                 elif j < i:

+ 138 - 47
paddlex/cv/nets/mobilenet_v3.py

@@ -42,7 +42,9 @@ class MobileNetV3():
                  extra_block_filters=[[256, 512], [128, 256], [128, 256],
                                       [64, 128]],
                  num_classes=None,
-                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
+                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+                 for_seg=False,
+                 output_stride=None):
         assert len(lr_mult_list) == 5, \
             "lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
             len(lr_mult_list))
@@ -57,48 +59,112 @@ class MobileNetV3():
         self.num_classes = num_classes
         self.lr_mult_list = lr_mult_list
         self.curr_stage = 0
-        if model_name == "large":
-            self.cfg = [
-                # kernel_size, expand, channel, se_block, act_mode, stride
-                [3, 16, 16, False, 'relu', 1],
-                [3, 64, 24, False, 'relu', 2],
-                [3, 72, 24, False, 'relu', 1],
-                [5, 72, 40, True, 'relu', 2],
-                [5, 120, 40, True, 'relu', 1],
-                [5, 120, 40, True, 'relu', 1],
-                [3, 240, 80, False, 'hard_swish', 2],
-                [3, 200, 80, False, 'hard_swish', 1],
-                [3, 184, 80, False, 'hard_swish', 1],
-                [3, 184, 80, False, 'hard_swish', 1],
-                [3, 480, 112, True, 'hard_swish', 1],
-                [3, 672, 112, True, 'hard_swish', 1],
-                [5, 672, 160, True, 'hard_swish', 2],
-                [5, 960, 160, True, 'hard_swish', 1],
-                [5, 960, 160, True, 'hard_swish', 1],
-            ]
-            self.cls_ch_squeeze = 960
-            self.cls_ch_expand = 1280
-            self.lr_interval = 3
-        elif model_name == "small":
-            self.cfg = [
-                # kernel_size, expand, channel, se_block, act_mode, stride
-                [3, 16, 16, True, 'relu', 2],
-                [3, 72, 24, False, 'relu', 2],
-                [3, 88, 24, False, 'relu', 1],
-                [5, 96, 40, True, 'hard_swish', 2],
-                [5, 240, 40, True, 'hard_swish', 1],
-                [5, 240, 40, True, 'hard_swish', 1],
-                [5, 120, 48, True, 'hard_swish', 1],
-                [5, 144, 48, True, 'hard_swish', 1],
-                [5, 288, 96, True, 'hard_swish', 2],
-                [5, 576, 96, True, 'hard_swish', 1],
-                [5, 576, 96, True, 'hard_swish', 1],
-            ]
-            self.cls_ch_squeeze = 576
-            self.cls_ch_expand = 1280
-            self.lr_interval = 2
+        self.for_seg = for_seg
+        self.decode_point = None
+
+        if self.for_seg:
+            if model_name == "large":
+                self.cfg = [
+                    # k, exp, c,  se,     nl,  s,
+                    [3, 16, 16, False, 'relu', 1],
+                    [3, 64, 24, False, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 1],
+                    [5, 72, 40, True, 'relu', 2],
+                    [5, 120, 40, True, 'relu', 1],
+                    [5, 120, 40, True, 'relu', 1],
+                    [3, 240, 80, False, 'hard_swish', 2],
+                    [3, 200, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 480, 112, True, 'hard_swish', 1],
+                    [3, 672, 112, True, 'hard_swish', 1],
+                    # The number of channels in the last 4 stages is reduced by a
+                    # factor of 2 compared to the standard implementation.
+                    [5, 336, 80, True, 'hard_swish', 2],
+                    [5, 480, 80, True, 'hard_swish', 1],
+                    [5, 480, 80, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 480
+                self.cls_ch_expand = 1280
+                self.lr_interval = 3
+            elif model_name == "small":
+                self.cfg = [
+                    # k, exp, c,  se,     nl,  s,
+                    [3, 16, 16, True, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 2],
+                    [3, 88, 24, False, 'relu', 1],
+                    [5, 96, 40, True, 'hard_swish', 2],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 120, 48, True, 'hard_swish', 1],
+                    [5, 144, 48, True, 'hard_swish', 1],
+                    # The number of channels in the last 4 stages is reduced by a
+                    # factor of 2 compared to the standard implementation.
+                    [5, 144, 48, True, 'hard_swish', 2],
+                    [5, 288, 48, True, 'hard_swish', 1],
+                    [5, 288, 48, True, 'hard_swish', 1],
+                ]
+            else:
+                raise NotImplementedError
         else:
-            raise NotImplementedError
+            if model_name == "large":
+                self.cfg = [
+                    # kernel_size, expand, channel, se_block, act_mode, stride
+                    [3, 16, 16, False, 'relu', 1],
+                    [3, 64, 24, False, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 1],
+                    [5, 72, 40, True, 'relu', 2],
+                    [5, 120, 40, True, 'relu', 1],
+                    [5, 120, 40, True, 'relu', 1],
+                    [3, 240, 80, False, 'hard_swish', 2],
+                    [3, 200, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 480, 112, True, 'hard_swish', 1],
+                    [3, 672, 112, True, 'hard_swish', 1],
+                    [5, 672, 160, True, 'hard_swish', 2],
+                    [5, 960, 160, True, 'hard_swish', 1],
+                    [5, 960, 160, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 960
+                self.cls_ch_expand = 1280
+                self.lr_interval = 3
+            elif model_name == "small":
+                self.cfg = [
+                    # kernel_size, expand, channel, se_block, act_mode, stride
+                    [3, 16, 16, True, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 2],
+                    [3, 88, 24, False, 'relu', 1],
+                    [5, 96, 40, True, 'hard_swish', 2],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 120, 48, True, 'hard_swish', 1],
+                    [5, 144, 48, True, 'hard_swish', 1],
+                    [5, 288, 96, True, 'hard_swish', 2],
+                    [5, 576, 96, True, 'hard_swish', 1],
+                    [5, 576, 96, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 576
+                self.cls_ch_expand = 1280
+                self.lr_interval = 2
+            else:
+                raise NotImplementedError
+
+        if self.for_seg:
+            self.modify_bottle_params(output_stride)
+
+    def modify_bottle_params(self, output_stride=None):
+        if output_stride is not None and output_stride % 2 != 0:
+            raise Exception("output stride must to be even number")
+        if output_stride is None:
+            return
+        else:
+            stride = 2
+            for i, _cfg in enumerate(self.cfg):
+                stride = stride * _cfg[-1]
+                if stride > output_stride:
+                    s = 1
+                    self.cfg[i][-1] = s
 
     def _conv_bn_layer(self,
                        input,
@@ -153,6 +219,14 @@ class MobileNetV3():
                 bn = fluid.layers.relu6(bn)
         return bn
 
+    def make_divisible(self, v, divisor=8, min_value=None):
+        if min_value is None:
+            min_value = divisor
+        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+        if new_v < 0.9 * v:
+            new_v += divisor
+        return new_v
+
     def _hard_swish(self, x):
         return x * fluid.layers.relu6(x + 3) / 6.
 
@@ -220,6 +294,9 @@ class MobileNetV3():
             use_cudnn=False,
             name=name + '_depthwise')
 
+        if self.curr_stage == 5:
+            self.decode_point = conv1
+
         if use_se:
             conv1 = self._se_block(
                 input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
@@ -282,7 +359,7 @@ class MobileNetV3():
         conv = self._conv_bn_layer(
             input,
             filter_size=3,
-            num_filters=inplanes if scale <= 1.0 else int(inplanes * scale),
+            num_filters=self.make_divisible(inplanes * scale),
             stride=2,
             padding=1,
             num_groups=1,
@@ -290,6 +367,7 @@ class MobileNetV3():
             act='hard_swish',
             name='conv1')
         i = 0
+        inplanes = self.make_divisible(inplanes * scale)
         for layer_cfg in cfg:
             self.block_stride *= layer_cfg[5]
             if layer_cfg[5] == 2:
@@ -297,19 +375,32 @@ class MobileNetV3():
             conv = self._residual_unit(
                 input=conv,
                 num_in_filter=inplanes,
-                num_mid_filter=int(scale * layer_cfg[1]),
-                num_out_filter=int(scale * layer_cfg[2]),
+                num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
+                num_out_filter=self.make_divisible(scale * layer_cfg[2]),
                 act=layer_cfg[4],
                 stride=layer_cfg[5],
                 filter_size=layer_cfg[0],
                 use_se=layer_cfg[3],
                 name='conv' + str(i + 2))
-
-            inplanes = int(scale * layer_cfg[2])
+            inplanes = self.make_divisible(scale * layer_cfg[2])
             i += 1
             self.curr_stage = i
         blocks.append(conv)
 
+        if self.for_seg:
+            conv = self._conv_bn_layer(
+                input=conv,
+                filter_size=1,
+                num_filters=self.make_divisible(scale * self.cls_ch_squeeze),
+                stride=1,
+                padding=0,
+                num_groups=1,
+                if_act=True,
+                act='hard_swish',
+                name='conv_last')
+
+            return conv, self.decode_point
+
         if self.num_classes:
             conv = self._conv_bn_layer(
                 input=conv,

+ 216 - 124
paddlex/cv/nets/segmentation/deeplabv3p.py

@@ -21,7 +21,7 @@ from collections import OrderedDict
 
 import paddle.fluid as fluid
 from .model_utils.libs import scope, name_scope
-from .model_utils.libs import bn, bn_relu, relu
+from .model_utils.libs import bn, bn_relu, relu, qsigmoid
 from .model_utils.libs import conv, max_pool, deconv
 from .model_utils.libs import separate_conv
 from .model_utils.libs import sigmoid_to_softmax
@@ -82,7 +82,17 @@ class DeepLabv3p(object):
                  use_dice_loss=False,
                  class_weight=None,
                  ignore_index=255,
-                 fixed_input_shape=None):
+                 fixed_input_shape=None,
+                 pooling_stride=[1, 1],
+                 pooling_crop_size=None,
+                 aspp_with_se=False,
+                 se_use_qsigmoid=False,
+                 aspp_convs_filters=256,
+                 aspp_with_concat_projection=True,
+                 add_image_level_feature=True,
+                 use_sum_merge=False,
+                 conv_filters=256,
+                 output_is_logits=False):
         # dice_loss或bce_loss只适用两类分割中
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise ValueError(
@@ -117,6 +127,17 @@ class DeepLabv3p(object):
         self.encoder_with_aspp = encoder_with_aspp
         self.enable_decoder = enable_decoder
         self.fixed_input_shape = fixed_input_shape
+        self.output_is_logits = output_is_logits
+        self.aspp_convs_filters = aspp_convs_filters
+        self.output_stride = output_stride
+        self.pooling_crop_size = pooling_crop_size
+        self.pooling_stride = pooling_stride
+        self.se_use_qsigmoid = se_use_qsigmoid
+        self.aspp_with_concat_projection = aspp_with_concat_projection
+        self.add_image_level_feature = add_image_level_feature
+        self.aspp_with_se = aspp_with_se
+        self.use_sum_merge = use_sum_merge
+        self.conv_filters = conv_filters
 
     def _encoder(self, input):
         # 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
@@ -129,19 +150,36 @@ class DeepLabv3p(object):
         elif self.output_stride == 8:
             aspp_ratios = [12, 24, 36]
         else:
-            raise Exception("DeepLabv3p only support stride 8 or 16")
+            aspp_ratios = []
 
         param_attr = fluid.ParamAttr(
             name=name_scope + 'weights',
             regularizer=None,
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.06))
+
+        concat_logits = []
         with scope('encoder'):
-            channel = 256
+            channel = self.aspp_convs_filters
             with scope("image_pool"):
-                image_avg = fluid.layers.reduce_mean(
-                    input, [2, 3], keep_dim=True)
-                image_avg = bn_relu(
+                if self.pooling_crop_size is None:
+                    image_avg = fluid.layers.reduce_mean(
+                        input, [2, 3], keep_dim=True)
+                else:
+                    pool_w = int((self.pooling_crop_size[0] - 1.0) /
+                                 self.output_stride + 1.0)
+                    pool_h = int((self.pooling_crop_size[1] - 1.0) /
+                                 self.output_stride + 1.0)
+                    image_avg = fluid.layers.pool2d(
+                        input,
+                        pool_size=(pool_h, pool_w),
+                        pool_stride=self.pooling_stride,
+                        pool_type='avg',
+                        pool_padding='VALID')
+
+                act = qsigmoid if self.se_use_qsigmoid else bn_relu
+
+                image_avg = act(
                     conv(
                         image_avg,
                         channel,
@@ -153,6 +191,8 @@ class DeepLabv3p(object):
                 input_shape = fluid.layers.shape(input)
                 image_avg = fluid.layers.resize_bilinear(image_avg,
                                                          input_shape[2:])
+                if self.add_image_level_feature:
+                    concat_logits.append(image_avg)
 
             with scope("aspp0"):
                 aspp0 = bn_relu(
@@ -164,77 +204,160 @@ class DeepLabv3p(object):
                         groups=1,
                         padding=0,
                         param_attr=param_attr))
-            with scope("aspp1"):
-                if self.aspp_with_sep_conv:
-                    aspp1 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[0],
-                        act=relu)
-                else:
-                    aspp1 = bn_relu(
-                        conv(
+                concat_logits.append(aspp0)
+
+            if aspp_ratios:
+                with scope("aspp1"):
+                    if self.aspp_with_sep_conv:
+                        aspp1 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[0],
-                            padding=aspp_ratios[0],
-                            param_attr=param_attr))
-            with scope("aspp2"):
-                if self.aspp_with_sep_conv:
-                    aspp2 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[1],
-                        act=relu)
-                else:
-                    aspp2 = bn_relu(
-                        conv(
+                            act=relu)
+                    else:
+                        aspp1 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[0],
+                                padding=aspp_ratios[0],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp1)
+                with scope("aspp2"):
+                    if self.aspp_with_sep_conv:
+                        aspp2 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[1],
-                            padding=aspp_ratios[1],
-                            param_attr=param_attr))
-            with scope("aspp3"):
-                if self.aspp_with_sep_conv:
-                    aspp3 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[2],
-                        act=relu)
-                else:
-                    aspp3 = bn_relu(
-                        conv(
+                            act=relu)
+                    else:
+                        aspp2 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[1],
+                                padding=aspp_ratios[1],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp2)
+                with scope("aspp3"):
+                    if self.aspp_with_sep_conv:
+                        aspp3 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[2],
-                            padding=aspp_ratios[2],
-                            param_attr=param_attr))
+                            act=relu)
+                    else:
+                        aspp3 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[2],
+                                padding=aspp_ratios[2],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp3)
+
             with scope("concat"):
-                data = fluid.layers.concat(
-                    [image_avg, aspp0, aspp1, aspp2, aspp3], axis=1)
-                data = bn_relu(
+                data = fluid.layers.concat(concat_logits, axis=1)
+                if self.aspp_with_concat_projection:
+                    data = bn_relu(
+                        conv(
+                            data,
+                            channel,
+                            1,
+                            1,
+                            groups=1,
+                            padding=0,
+                            param_attr=param_attr))
+                    data = fluid.layers.dropout(data, 0.9)
+            if self.aspp_with_se:
+                data = data * image_avg
+            return data
+
+    def _decoder_with_sum_merge(self, encode_data, decode_shortcut,
+                                param_attr):
+        decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
+        encode_data = fluid.layers.resize_bilinear(encode_data,
+                                                   decode_shortcut_shape[2:])
+
+        encode_data = conv(
+            encode_data,
+            self.conv_filters,
+            1,
+            1,
+            groups=1,
+            padding=0,
+            param_attr=param_attr)
+
+        with scope('merge'):
+            decode_shortcut = conv(
+                decode_shortcut,
+                self.conv_filters,
+                1,
+                1,
+                groups=1,
+                padding=0,
+                param_attr=param_attr)
+
+            return encode_data + decode_shortcut
+
+    def _decoder_with_concat(self, encode_data, decode_shortcut, param_attr):
+        with scope('concat'):
+            decode_shortcut = bn_relu(
+                conv(
+                    decode_shortcut,
+                    48,
+                    1,
+                    1,
+                    groups=1,
+                    padding=0,
+                    param_attr=param_attr))
+
+            decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
+            encode_data = fluid.layers.resize_bilinear(
+                encode_data, decode_shortcut_shape[2:])
+            encode_data = fluid.layers.concat(
+                [encode_data, decode_shortcut], axis=1)
+        if self.decoder_use_sep_conv:
+            with scope("separable_conv1"):
+                encode_data = separate_conv(
+                    encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
+            with scope("separable_conv2"):
+                encode_data = separate_conv(
+                    encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
+        else:
+            with scope("decoder_conv1"):
+                encode_data = bn_relu(
                     conv(
-                        data,
-                        channel,
-                        1,
-                        1,
-                        groups=1,
-                        padding=0,
+                        encode_data,
+                        self.conv_filters,
+                        stride=1,
+                        filter_size=3,
+                        dilation=1,
+                        padding=1,
                         param_attr=param_attr))
-                data = fluid.layers.dropout(data, 0.9)
-            return data
+            with scope("decoder_conv2"):
+                encode_data = bn_relu(
+                    conv(
+                        encode_data,
+                        self.conv_filters,
+                        stride=1,
+                        filter_size=3,
+                        dilation=1,
+                        padding=1,
+                        param_attr=param_attr))
+        return encode_data
 
     def _decoder(self, encode_data, decode_shortcut):
         # 解码器配置
@@ -246,52 +369,14 @@ class DeepLabv3p(object):
             regularizer=None,
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.06))
+
         with scope('decoder'):
-            with scope('concat'):
-                decode_shortcut = bn_relu(
-                    conv(
-                        decode_shortcut,
-                        48,
-                        1,
-                        1,
-                        groups=1,
-                        padding=0,
-                        param_attr=param_attr))
+            if self.use_sum_merge:
+                return self._decoder_with_sum_merge(
+                    encode_data, decode_shortcut, param_attr)
 
-                decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
-                encode_data = fluid.layers.resize_bilinear(
-                    encode_data, decode_shortcut_shape[2:])
-                encode_data = fluid.layers.concat(
-                    [encode_data, decode_shortcut], axis=1)
-            if self.decoder_use_sep_conv:
-                with scope("separable_conv1"):
-                    encode_data = separate_conv(
-                        encode_data, 256, 1, 3, dilation=1, act=relu)
-                with scope("separable_conv2"):
-                    encode_data = separate_conv(
-                        encode_data, 256, 1, 3, dilation=1, act=relu)
-            else:
-                with scope("decoder_conv1"):
-                    encode_data = bn_relu(
-                        conv(
-                            encode_data,
-                            256,
-                            stride=1,
-                            filter_size=3,
-                            dilation=1,
-                            padding=1,
-                            param_attr=param_attr))
-                with scope("decoder_conv2"):
-                    encode_data = bn_relu(
-                        conv(
-                            encode_data,
-                            256,
-                            stride=1,
-                            filter_size=3,
-                            dilation=1,
-                            padding=1,
-                            param_attr=param_attr))
-            return encode_data
+            return self._decoder_with_concat(encode_data, decode_shortcut,
+                                             param_attr)
 
     def _get_loss(self, logit, label, mask):
         avg_loss = 0
@@ -335,8 +420,11 @@ class DeepLabv3p(object):
             self.num_classes = 1
         image = inputs['image']
 
-        data, decode_shortcuts = self.backbone(image)
-        decode_shortcut = decode_shortcuts[self.backbone.decode_points]
+        if 'MobileNetV3' in self.backbone.__class__.__name__:
+            data, decode_shortcut = self.backbone(image)
+        else:
+            data, decode_shortcuts = self.backbone(image)
+            decode_shortcut = decode_shortcuts[self.backbone.decode_points]
 
         # 编码器解码器设置
         if self.encoder_with_aspp:
@@ -351,18 +439,22 @@ class DeepLabv3p(object):
                 regularization_coeff=0.0),
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.01))
-        with scope('logit'):
-            with fluid.name_scope('last_conv'):
-                logit = conv(
-                    data,
-                    self.num_classes,
-                    1,
-                    stride=1,
-                    padding=0,
-                    bias_attr=True,
-                    param_attr=param_attr)
-            image_shape = fluid.layers.shape(image)
-            logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
+        if not self.output_is_logits:
+            with scope('logit'):
+                with fluid.name_scope('last_conv'):
+                    logit = conv(
+                        data,
+                        self.num_classes,
+                        1,
+                        stride=1,
+                        padding=0,
+                        bias_attr=True,
+                        param_attr=param_attr)
+        else:
+            logit = data
+
+        image_shape = fluid.layers.shape(image)
+        logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
 
         if self.num_classes == 1:
             out = sigmoid_to_softmax(logit)

+ 4 - 0
paddlex/cv/nets/segmentation/model_utils/libs.py

@@ -112,6 +112,10 @@ def bn_relu(data, norm_type='bn', eps=1e-5):
     return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps))
 
 
+def qsigmoid(data):
+    return fluid.layers.relu6(data + 3) * 0.16667
+
+
 def relu(data):
     return fluid.layers.relu(data)
 

+ 2 - 2
paddlex/cv/transforms/seg_transforms.py

@@ -73,8 +73,6 @@ class Compose(SegTransform):
             tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
         """
 
-        if im_info is None:
-            im_info = list()
         if isinstance(im, np.ndarray):
             if len(im.shape) != 3:
                 raise Exception(
@@ -86,6 +84,8 @@ class Compose(SegTransform):
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(im))
         im = im.astype('float32')
+        if im_info is None:
+            im_info = [('origin_shape', im.shape[0:2])]
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if label is not None:

+ 3 - 3
paddlex/cv/transforms/visualize.py

@@ -191,9 +191,9 @@ def det_compose(im,
                     bboxes[:, 3] = bboxes[:, 3] * h_scale
                 else:
                     bboxes = outputs[2]['gt_bbox']
-                if not isinstance(
-                        op,
-                        pdx.cv.transforms.det_transforms.RandomHorizontalFlip):
+                if not isinstance(op, (
+                        pdx.cv.transforms.det_transforms.RandomHorizontalFlip,
+                        pdx.cv.transforms.det_transforms.Padding)):
                     for i in range(bboxes.shape[0]):
                         bbox = bboxes[i]
                         cname = labels[outputs[2]['gt_class'][i][0] - 1]

+ 7 - 5
paddlex/deploy.py

@@ -19,7 +19,9 @@ import yaml
 import paddlex
 import paddle.fluid as fluid
 from paddlex.cv.transforms import build_transforms
-from paddlex.cv.models import BaseClassifier, YOLOv3, FasterRCNN, MaskRCNN, DeepLabv3p
+from paddlex.cv.models import BaseClassifier
+from paddlex.cv.models import PPYOLO, FasterRCNN, MaskRCNN
+from paddlex.cv.models import DeepLabv3p
 
 
 class Predictor:
@@ -129,8 +131,8 @@ class Predictor:
                 thread_num=thread_num)
             res['image'] = im
         elif self.model_type == "detector":
-            if self.model_name == "YOLOv3":
-                im, im_size = YOLOv3._preprocess(
+            if self.model_name in ["PPYOLO", "YOLOv3"]:
+                im, im_size = PPYOLO._preprocess(
                     image,
                     self.transforms,
                     self.model_type,
@@ -190,8 +192,8 @@ class Predictor:
             res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
             res['im_id'] = (np.array(
                 [[i] for i in range(batch_size)]).astype('int32'), [[]])
-            if self.model_name == "YOLOv3":
-                preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
+            if self.model_name in ["PPYOLO", "YOLOv3"]:
+                preds = PPYOLO._postprocess(res, batch_size, self.num_classes,
                                             self.labels)
             elif self.model_name == "FasterRCNN":
                 preds = FasterRCNN._postprocess(res, batch_size,

+ 1 - 0
requirements.txt

@@ -8,3 +8,4 @@ paddleslim == 1.0.1
 shapely
 x2paddle
 paddlepaddle-gpu
+opencv-python

+ 1 - 1
setup.py

@@ -31,7 +31,7 @@ setuptools.setup(
     install_requires=[
         "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
         'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.6.2',
-        'shapely>=1.7.0'
+        'shapely>=1.7.0', "opencv-python"
     ],
     classifiers=[
         "Programming Language :: Python :: 3",

+ 1 - 0
tutorials/train/README.md

@@ -20,6 +20,7 @@
 |instance_segmentation/mask_rcnn_r18_fpn.py | 实例分割MaskRCNN | 小度熊分拣 |
 |instance_segmentation/mask_rcnn_f50_fpn.py | 实例分割MaskRCNN | 小度熊分拣 |
 |semantic_segmentation/deeplabv3p_mobilenetv2.py | 语义分割DeepLabV3 | 视盘分割 |
+|semantic_segmentation/deeplabv3p_mobilenetv2.py | 语义分割DeepLabV3 | 视盘分割 |
 |semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py | 语义分割DeepLabV3 | 视盘分割 |
 |semantic_segmentation/deeplabv3p_xception65.py | 语义分割DeepLabV3 | 视盘分割 |
 |semantic_segmentation/fast_scnn.py | 语义分割FastSCNN | 视盘分割 |

+ 58 - 0
tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py

@@ -0,0 +1,58 @@
+# 环境变量配置,用于控制是否使用GPU
+# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# 下载和解压视盘分割数据集
+optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
+pdx.utils.download_and_decompress(optic_dataset, path='./')
+
+# 定义训练和验证时的transforms
+# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
+    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.ResizeByLong(long_size=512),
+    transforms.Padding(target_size=512), transforms.Normalize()
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/train_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=eval_transforms)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
+model = pdx.seg.DeepLabv3p(
+    num_classes=num_classes,
+    backbone='MobileNetV3_large_x1_0_ssld',
+    pooling_crop_size=(512, 512))
+
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
+# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=40,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
+    use_vdl=True)