
add human segmentation

FlyingQianMM 5 years ago
Parent
Commit
3126248705

+ 183 - 0
Applications/HumanSeg/README.md

@@ -0,0 +1,183 @@
+# HumanSeg: Human Segmentation Models
+
+Built on PaddleSeg's core segmentation networks, this tutorial is an end-to-end guide to human segmentation: pretrained models, fine-tuning, and video segmentation inference and deployment.
+
+## Installation
+
+**Prerequisites**
+* paddlepaddle >= 1.8.0
+* python >= 3.5
+* cython
+* pycocotools
+
+```bash
+pip install paddlex -i https://mirror.baidu.com/pypi/simple
+```
+For installation issues, see [PaddleX Installation](https://paddlex.readthedocs.io/zh_CN/latest/install.html).
+
+## Pretrained Models
+HumanSeg releases two models pretrained on large-scale human image data, covering a range of use cases.
+
+| Model Type | Checkpoint | Inference Model | Quant Inference Model | Notes |
+| --- | --- | --- | ---| --- |
+| HumanSeg-server  | [humanseg_server_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip) | [humanseg_server_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip) | -- | High-accuracy model for server-side GPU inference on complex backgrounds; architecture: DeepLabv3+/Xception65; input size (512, 512) |
+| HumanSeg-mobile | [humanseg_mobile_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip) | [humanseg_mobile_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip) | [humanseg_mobile_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip) | Lightweight model for front-camera scenes on mobile devices or server-side CPU; architecture: HRNet_w18_small_v1; input size (192, 192) |
+
+
+Model performance:
+
+| Model | Model Size | Latency |
+| --- | --- | --- |
+|humanseg_server_inference| 158M | - |
+|humanseg_mobile_inference | 5.8M | 42.35ms |
+|humanseg_mobile_quant | 1.6M | 24.93ms |
+
+Latency measured on a Xiaomi phone (CPU: Snapdragon 855, RAM: 6 GB) with 192×192 input.
+
+
+**NOTE:**
+
+* Checkpoint contains the model weights and is used for fine-tuning.
+
+* Inference Model and Quant Inference Model are deployment models, containing the `__model__` computation graph, the `__params__` model parameters, and `model.yaml` with basic model configuration.
+
+* Inference Model targets server-side CPU and GPU deployment; Quant Inference Model is the quantized version, for deployment to mobile and other edge devices via Paddle Lite.
+
+Run the following script to download the HumanSeg pretrained models:
+```bash
+python pretrain_weights/download_pretrain_weights.py
+```
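+
+Once downloaded, an inference model directory can be loaded directly with PaddleX. Here is a minimal sketch, assuming PaddleX 1.x's `predict` API (it returns a dict with `label_map` and `score_map`):
+
+```python
+import paddlex as pdx
+
+# Load a downloaded inference model (contains __model__, __params__ and model.yaml)
+model = pdx.load_model('pretrain_weights/humanseg_mobile_inference')
+result = model.predict('data/human_image.jpg')
+label_map = result['label_map']  # per-pixel class ids (0: background, 1: human)
+```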
+
+## Download Test Data
+We take the **Supervisely Persons** human segmentation dataset released by [supervise.ly](https://supervise.ly/), randomly sample a small subset, and convert it into a format PaddleSeg can load directly. Run the code below to download it; the package also includes `video_test.mp4`, a test video of a person shot with a phone's front camera.
+
+```bash
+python data/download_data.py
+```
+
+## Quick Start: Video Human Segmentation
+Predictions from the DIS (Dense Inverse Search-based) optical flow algorithm are fused with the per-frame segmentation results to improve human segmentation on video streams.
+```bash
+# Real-time segmentation from your computer's camera
+python video_infer.py --model_dir pretrain_weights/humanseg_lite_inference
+
+# Segment a human video file
+python video_infer.py --model_dir pretrain_weights/humanseg_lite_inference --video_path data/video_test.mp4
+```
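+
+A condensed sketch of the per-frame fusion loop in `video_infer.py` (`postprocess` and `threshold_mask` live in `utils/humanseg_postprocess.py`; `frames` and `get_score_map` are hypothetical stand-ins for the capture loop and the model forward pass):
+
+```python
+import cv2
+import numpy as np
+from utils.humanseg_postprocess import postprocess, threshold_mask
+
+disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+prev_gray = np.zeros((192, 192), np.uint8)    # previous frame, grayscale
+prev_cfd = np.zeros((192, 192), np.float32)   # previous fused confidence map
+is_init = True
+
+for frame in frames:  # hypothetical: BGR frames from cv2.VideoCapture
+    score_map = get_score_map(frame)  # hypothetical: model forward pass, 192x192 map in [0, 255]
+    cur_gray = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), (192, 192))
+    # Fuse the optical-flow-tracked previous mask with the current prediction
+    fused = postprocess(cur_gray, score_map, prev_gray, prev_cfd, disflow, is_init)
+    prev_gray, prev_cfd, is_init = cur_gray.copy(), fused.copy(), False
+    # 'mask' is the temporally smoothed alpha matte at network resolution
+    mask = threshold_mask(cv2.GaussianBlur(fused, (3, 3), 0), thresh_bg=0.2, thresh_fg=0.8)
+```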
+
+Video segmentation results:
+
+<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/result.gif" width="20%" height="20%">
+
+Replace the background with one of your choice; the background can be an image or a video.
+```bash
+# Real-time background replacement from your computer's camera; a background video can also be passed via '--background_video_path'
+python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --background_image_path data/background.jpg
+
+# Replace the background of a human video; a background video can also be passed via '--background_video_path'
+python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg
+
+# Replace the background of a single image
+python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
+
+```
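+
+Under the hood, replacement is a per-pixel alpha blend of the frame and the background, weighted by the segmentation confidence; this sketch mirrors `bg_replace()` in `bg_replace.py` (the name `blend` is illustrative):
+
+```python
+import cv2
+import numpy as np
+
+def blend(score_map, img, bg):
+    # score_map: HxW confidence in [0, 1]; img, bg: BGR images
+    h, w, _ = img.shape
+    bg = cv2.resize(bg, (w, h))                                # match the frame size
+    alpha = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)  # broadcast over channels
+    return (alpha * img + (1 - alpha) * bg).astype(np.uint8)
+```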
+
+Background replacement results:
+
+<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/bg_replace.gif" width="20%" height="20%">
+
+
+**NOTE**:
+
+Processing a video takes a few minutes; please be patient.
+
+The released models target portrait footage from phone front cameras; results on landscape footage will be slightly worse.
+
+## Training
+Fine-tune from a pretrained model with the command below. Make sure the chosen architecture `--model_type` matches the pretrained weights `--pretrain_weights`.
+```bash
+python train.py --model_type HumanSegMobile \
+--save_dir output/ \
+--data_dir data/mini_supervisely \
+--train_list data/mini_supervisely/train.txt \
+--val_list data/mini_supervisely/val.txt \
+--pretrain_weights pretrain_weights/humanseg_mobile_ckpt \
+--batch_size 8 \
+--learning_rate 0.001 \
+--num_epochs 10 \
+--image_shape 192 192
+```
+Parameter descriptions:
+* `--model_type`: model type; one of HumanSegServer or HumanSegMobile
+* `--save_dir`: directory for saving the model
+* `--data_dir`: dataset root directory
+* `--train_list`: path to the training set list file
+* `--val_list`: path to the validation set list file
+* `--pretrain_weights`: path to the pretrained weights
+* `--batch_size`: mini-batch size
+* `--learning_rate`: initial learning rate
+* `--num_epochs`: number of training epochs
+* `--image_shape`: network input size (w, h)
+
+For more command-line options, run:
+```bash
+python train.py --help
+```
+**NOTE**
+You can quickly try other models by switching `--model_type` together with the matching `--pretrain_weights`.
+
+## Evaluation
+Evaluate with the following command:
+```bash
+python eval.py --model_dir output/best_model \
+--data_dir data/mini_supervisely \
+--val_list data/mini_supervisely/val.txt \
+--image_shape 192 192
+```
+Parameter descriptions:
+* `--model_dir`: model directory
+* `--data_dir`: dataset root directory
+* `--val_list`: path to the validation set list file
+* `--image_shape`: network input size (w, h)
+
+## Prediction
+Run prediction with the command below; results are saved to `./output/result/` by default.
+```bash
+python infer.py --model_dir output/best_model \
+--data_dir data/mini_supervisely \
+--test_list data/mini_supervisely/test.txt \
+--save_dir output/result \
+--image_shape 192 192
+```
+Parameter descriptions:
+* `--model_dir`: model directory
+* `--data_dir`: dataset root directory
+* `--test_list`: path to the test set list file
+* `--save_dir`: directory for saving prediction results
+* `--image_shape`: network input size (w, h)
+
+## Model Export
+```bash
+paddlex --export_inference --model_dir output/best_model \
+--save_dir output/export
+```
+Parameter descriptions:
+* `--model_dir`: model directory
+* `--save_dir`: directory for saving the exported model
+
+## Offline Quantization
+```bash
+python quant_offline.py --model_dir output/best_model \
+--data_dir data/mini_supervisely \
+--quant_list data/mini_supervisely/val.txt \
+--save_dir output/quant_offline \
+--image_shape 192 192
+```
+Parameter descriptions:
+* `--model_dir`: directory of the model to quantize
+* `--data_dir`: dataset root directory
+* `--quant_list`: path to the calibration set list file; usually the training or validation list
+* `--save_dir`: directory for saving the quantized model
+* `--image_shape`: network input size (w, h)
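+
+Internally, `quant_offline.py` builds a calibration dataset with the evaluation transforms and calls PaddleX's post-training quantization API, roughly as follows (paths match the README's example commands):
+
+```python
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# Calibration data uses the same preprocessing as evaluation
+eval_transforms = transforms.Compose(
+    [transforms.Resize([192, 192]), transforms.Normalize()])
+dataset = pdx.datasets.SegDataset(
+    data_dir='data/mini_supervisely',
+    file_list='data/mini_supervisely/val.txt',
+    transforms=eval_transforms)
+
+model = pdx.load_model('output/best_model')
+# Post-training (offline) quantization over 10 calibration batches of size 1
+pdx.slim.export_quant_model(model, dataset, 1, 10, 'output/quant_offline')
+```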
+
+## AI Studio Online Tutorial
+
+An interactive human segmentation tutorial is available on the AI Studio platform: [try it online](https://aistudio.baidu.com/aistudio/projectdetail/475345).

+ 290 - 0
Applications/HumanSeg/bg_replace.py

@@ -0,0 +1,290 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+import cv2
+import numpy as np
+
+from utils.humanseg_postprocess import postprocess, threshold_mask
+import paddlex as pdx
+import paddlex.utils.logging as logging
+from paddlex.seg import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='HumanSeg inference for video')
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='Model path for inference',
+        type=str)
+    parser.add_argument(
+        '--image_path',
+        dest='image_path',
+        help='Image including human',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--background_image_path',
+        dest='background_image_path',
+        help='Background image for replacing',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--video_path',
+        dest='video_path',
+        help='Video path for inference',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--background_video_path',
+        dest='background_video_path',
+        help='Background video path for replacing',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output')
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+
+    return parser.parse_args()
+
+
+def predict(img, model, test_transforms):
+    model.arrange_transforms(transforms=test_transforms, mode='test')
+    img, im_info = test_transforms(img.astype('float32'))
+    img = np.expand_dims(img, axis=0)
+    result = model.exe.run(model.test_prog,
+                           feed={'image': img},
+                           fetch_list=list(model.test_outputs.values()))
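+    # The second fetch is the per-class score map (NCHW); drop the batch dim and move to HWC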
+    score_map = result[1]
+    score_map = np.squeeze(score_map, axis=0)
+    score_map = np.transpose(score_map, (1, 2, 0))
+    return score_map, im_info
+
+
+def recover(img, im_info):
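+    # Undo preprocessing in reverse order: resize back to the pre-resize shape, crop off padding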
+    for info in im_info[::-1]:
+        if info[0] == 'resize':
+            w, h = info[1][1], info[1][0]
+            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
+        elif info[0] == 'padding':
+            w, h = info[1][1], info[1][0]
+            img = img[0:h, 0:w, :]
+    return img
+
+
+def bg_replace(score_map, img, bg):
+    h, w, _ = img.shape
+    bg = cv2.resize(bg, (w, h))
+    score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
+    comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
+    return comb
+
+
+def infer(args):
+    resize_h = args.image_shape[1]
+    resize_w = args.image_shape[0]
+
+    test_transforms = transforms.Compose(
+        [transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
+    model = pdx.load_model(args.model_dir)
+
+    if not osp.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    # Image background replacement
+    if args.image_path is not None:
+        if not osp.exists(args.image_path):
+            raise Exception('The --image_path does not exist: {}'.format(
+                args.image_path))
+        if args.background_image_path is None:
+            raise Exception(
+                'The --background_image_path is not set. Please set it')
+        else:
+            if not osp.exists(args.background_image_path):
+                raise Exception(
+                    'The --background_image_path does not exist: {}'.format(
+                        args.background_image_path))
+        img = cv2.imread(args.image_path)
+        score_map, im_info = predict(img, model, test_transforms)
+        score_map = score_map[:, :, 1]
+        score_map = recover(score_map, im_info)
+        bg = cv2.imread(args.background_image_path)
+        save_name = osp.basename(args.image_path)
+        save_path = osp.join(args.save_dir, save_name)
+        result = bg_replace(score_map, img, bg)
+        cv2.imwrite(save_path, result)
+
+    # Video background replacement: use the background video as background if provided, otherwise use the background image
+    else:
+        is_video_bg = False
+        if args.background_video_path is not None:
+            if not osp.exists(args.background_video_path):
+                raise Exception(
+                    'The --background_video_path does not exist: {}'.format(
+                        args.background_video_path))
+            is_video_bg = True
+        elif args.background_image_path is not None:
+            if not osp.exists(args.background_image_path):
+                raise Exception(
+                    'The --background_image_path does not exist: {}'.format(
+                        args.background_image_path))
+        else:
+            raise Exception(
+                'Please offer a background image or video. You should set --background_image_path or --background_video_path'
+            )
+
+        disflow = cv2.DISOpticalFlow_create(
+            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+        prev_gray = np.zeros((resize_h, resize_w), np.uint8)
+        prev_cfd = np.zeros((resize_h, resize_w), np.float32)
+        is_init = True
+        if args.video_path is not None:
+            logging.info('Please wait. Processing the video...')
+            if not osp.exists(args.video_path):
+                raise Exception('The --video_path does not exist: {}'.format(
+                    args.video_path))
+
+            cap_video = cv2.VideoCapture(args.video_path)
+            fps = cap_video.get(cv2.CAP_PROP_FPS)
+            width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            save_name = osp.basename(args.video_path)
+            save_name = save_name.split('.')[0]
+            save_path = osp.join(args.save_dir, save_name + '.avi')
+
+            cap_out = cv2.VideoWriter(
+                save_path,
+                cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
+                (width, height))
+
+            if is_video_bg:
+                cap_bg = cv2.VideoCapture(args.background_video_path)
+                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+                current_frame_bg = 1
+            else:
+                img_bg = cv2.imread(args.background_image_path)
+            while cap_video.isOpened():
+                ret, frame = cap_video.read()
+                if ret:
+                    score_map, im_info = predict(frame, model, test_transforms)
+                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                    score_map = 255 * score_map[:, :, 1]
+                    optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+                                              disflow, is_init)
+                    prev_gray = cur_gray.copy()
+                    prev_cfd = optflow_map.copy()
+                    is_init = False
+                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                    optflow_map = threshold_mask(
+                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                    score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop
+                    if is_video_bg:
+                        ret_bg, frame_bg = cap_bg.read()
+                        if ret_bg:
+                            if current_frame_bg == frames_bg:
+                                current_frame_bg = 1
+                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                        else:
+                            break
+                        current_frame_bg += 1
+                        comb = bg_replace(score_map, frame, frame_bg)
+                    else:
+                        comb = bg_replace(score_map, frame, img_bg)
+
+                    cap_out.write(comb)
+                else:
+                    break
+
+            if is_video_bg:
+                cap_bg.release()
+            cap_video.release()
+            cap_out.release()
+
+        # If neither an input image nor a video is given, open the camera
+        else:
+            cap_video = cv2.VideoCapture(0)
+            if not cap_video.isOpened():
+                raise IOError("Error opening video stream or file: check "
+                              "whether --video_path exists ({}) and whether "
+                              "the camera is working".format(args.video_path))
+
+            if is_video_bg:
+                cap_bg = cv2.VideoCapture(args.background_video_path)
+                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
+                current_frame_bg = 1
+            else:
+                img_bg = cv2.imread(args.background_image_path)
+            while cap_video.isOpened():
+                ret, frame = cap_video.read()
+                if ret:
+                    score_map, im_info = predict(frame, model, test_transforms)
+                    cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                    score_map = 255 * score_map[:, :, 1]
+                    optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+                                              disflow, is_init)
+                    prev_gray = cur_gray.copy()
+                    prev_cfd = optflow_map.copy()
+                    is_init = False
+                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                    optflow_map = threshold_mask(
+                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                    score_map = recover(optflow_map, im_info)
+
+                    # Read background frames in a loop
+                    if is_video_bg:
+                        ret_bg, frame_bg = cap_bg.read()
+                        if ret_bg:
+                            if current_frame_bg == frames_bg:
+                                current_frame_bg = 1
+                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                        else:
+                            break
+                        current_frame_bg += 1
+                        comb = bg_replace(score_map, frame, frame_bg)
+                    else:
+                        comb = bg_replace(score_map, frame, img_bg)
+                    cv2.imshow('HumanSegmentation', comb)
+                    if cv2.waitKey(1) & 0xFF == ord('q'):
+                        break
+                else:
+                    break
+            if is_video_bg:
+                cap_bg.release()
+            cap_video.release()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    infer(args)

Binary
Applications/HumanSeg/data/background.jpg


+ 35 - 0
Applications/HumanSeg/data/download_data.py

@@ -0,0 +1,35 @@
+# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
+sys.path.append(TEST_PATH)
+
+import paddlex as pdx
+
+
+def download_data(savepath):
+    url = "https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip"
+    pdx.utils.download_and_decompress(url=url, path=savepath)
+
+    url = "https://paddleseg.bj.bcebos.com/humanseg/data/video_test.zip"
+    pdx.utils.download_and_decompress(url=url, path=savepath)
+
+
+if __name__ == "__main__":
+    download_data(LOCAL_PATH)
+    print("Data download finish!")

Binary
Applications/HumanSeg/data/human_image.jpg


+ 85 - 0
Applications/HumanSeg/eval.py

@@ -0,0 +1,85 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddlex as pdx
+import paddlex.utils.logging as logging
+from paddlex.seg import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='HumanSeg evaluation')
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='Model path for evaluating',
+        type=str,
+        default='output/best_model')
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='The root directory of dataset',
+        type=str)
+    parser.add_argument(
+        '--val_list',
+        dest='val_list',
+        help='Val list file of dataset',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=128)
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+    return parser.parse_args()
+
+
+def dict2str(dict_input):
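+    """Format a metrics dict as 'k1=v1, k2=v2', rounding floats to 6 decimal places."""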
+    out = ''
+    for k, v in dict_input.items():
+        try:
+            v = round(float(v), 6)
+        except:
+            pass
+        out = out + '{}={}, '.format(k, v)
+    return out.strip(', ')
+
+
+def evaluate(args):
+    eval_transforms = transforms.Compose(
+        [transforms.Resize(args.image_shape), transforms.Normalize()])
+
+    eval_dataset = pdx.datasets.SegDataset(
+        data_dir=args.data_dir,
+        file_list=args.val_list,
+        transforms=eval_transforms)
+
+    model = pdx.load_model(args.model_dir)
+    metrics = model.evaluate(eval_dataset, args.batch_size)
+    logging.info('[EVAL] Finished, {} .'.format(dict2str(metrics)))
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    evaluate(args)

+ 48 - 0
Applications/HumanSeg/pretrain_weights/download_pretrain_weights.py

@@ -0,0 +1,48 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
+sys.path.append(TEST_PATH)
+
+import paddlex as pdx
+
+model_urls = {
+    "humanseg_server_ckpt":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip",
+    "humanseg_server_inference":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip",
+    "humanseg_mobile_ckpt":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip",
+    "humanseg_mobile_inference":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip",
+    "humanseg_mobile_quant":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip",
+    "humanseg_lite_ckpt":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_ckpt.zip",
+    "humanseg_lite_inference":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_inference.zip",
+    "humanseg_lite_quant":
+    "https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_quant.zip",
+}
+
+if __name__ == "__main__":
+    for model_name, url in model_urls.items():
+        pdx.utils.download_and_decompress(url=url, path=LOCAL_PATH)
+
+    print("Pretrained Model download success!")

+ 85 - 0
Applications/HumanSeg/quant_offline.py

@@ -0,0 +1,85 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddlex as pdx
+from paddlex.seg import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='HumanSeg offline quantization')
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='Model path for quant',
+        type=str,
+        default='output/best_model')
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=1)
+    parser.add_argument(
+        '--batch_nums',
+        dest='batch_nums',
+        help='Batch number for quant',
+        type=int,
+        default=10)
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='the root directory of dataset',
+        type=str)
+    parser.add_argument(
+        '--quant_list',
+        dest='quant_list',
+        help='Image file list for model quantization, it can be val.txt or train.txt',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the quant model',
+        type=str,
+        default='./output/quant_offline')
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+    return parser.parse_args()
+
+
+def quant(args):
+    eval_transforms = transforms.Compose(
+        [transforms.Resize(args.image_shape), transforms.Normalize()])
+
+    eval_dataset = pdx.datasets.SegDataset(
+        data_dir=args.data_dir,
+        file_list=args.quant_list,
+        transforms=eval_transforms)
+
+    model = pdx.load_model(args.model_dir)
+    pdx.slim.export_quant_model(model, eval_dataset, args.batch_size,
+                                args.batch_nums, args.save_dir)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    quant(args)

+ 154 - 50
Applications/HumanSeg/train.py

@@ -1,57 +1,161 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 # Use GPU card 0
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# To run on the CPU instead, uncomment:
+#os.environ['CUDA_VISIBLE_DEVICES'] = ''
+import argparse
 
 import paddlex as pdx
 from paddlex.seg import transforms
 
-# Download and extract the human segmentation dataset
-human_seg_data = 'https://paddlex.bj.bcebos.com/humanseg/data/human_seg_data.zip'
-pdx.utils.download_and_decompress(human_seg_data, path='./')
-
-# Download and extract the pretrained human segmentation weights
-pretrain_weights = 'https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip'
-pdx.utils.download_and_decompress(
-    pretrain_weights, path='./output/human_seg/pretrain')
-
-# Define transforms for training and validation
-train_transforms = transforms.Compose([
-    transforms.Resize([192, 192]), transforms.RandomHorizontalFlip(),
-    transforms.Normalize()
-])
-
-eval_transforms = transforms.Compose(
-    [transforms.Resize([192, 192]), transforms.Normalize()])
-
-# Define the training and validation datasets
-# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
-train_dataset = pdx.datasets.SegDataset(
-    data_dir='human_seg_data',
-    file_list='human_seg_data/train_list.txt',
-    label_list='human_seg_data/labels.txt',
-    transforms=train_transforms,
-    shuffle=True)
-eval_dataset = pdx.datasets.SegDataset(
-    data_dir='human_seg_data',
-    file_list='human_seg_data/val_list.txt',
-    label_list='human_seg_data/labels.txt',
-    transforms=eval_transforms)
-
-# Initialize the model and train
-# Training metrics can be viewed with VisualDL
-# Launch VisualDL: visualdl --logdir output/unet/vdl_log --port 8001
-# Then open https://0.0.0.0:8001 in a browser
-# 0.0.0.0 works for local access; for a remote server, use that machine's IP
-
-# https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
-num_classes = len(train_dataset.labels)
-model = pdx.seg.HRNet(num_classes=num_classes, width='18_small_v1')
-model.train(
-    num_epochs=10,
-    train_dataset=train_dataset,
-    train_batch_size=8,
-    eval_dataset=eval_dataset,
-    learning_rate=0.001,
-    pretrain_weights='./output/human_seg/pretrain/humanseg_mobile_ckpt',
-    save_dir='output/human_seg',
-    use_vdl=True)
+MODEL_TYPE = ['HumanSegMobile', 'HumanSegServer']
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='HumanSeg training')
+    parser.add_argument(
+        '--model_type',
+        dest='model_type',
+        help="Model type for traing, which is one of ('HumanSegMobile', 'HumanSegServer')",
+        type=str,
+        default='HumanSegMobile')
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='The root directory of dataset',
+        type=str)
+    parser.add_argument(
+        '--train_list',
+        dest='train_list',
+        help='Train list file of dataset',
+        type=str)
+    parser.add_argument(
+        '--val_list',
+        dest='val_list',
+        help='Val list file of dataset',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the model snapshot',
+        type=str,
+        default='./output')
+    parser.add_argument(
+        '--num_classes',
+        dest='num_classes',
+        help='Number of classes',
+        type=int,
+        default=2)
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+    parser.add_argument(
+        '--num_epochs',
+        dest='num_epochs',
+        help='Number epochs for training',
+        type=int,
+        default=100)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=128)
+    parser.add_argument(
+        '--learning_rate',
+        dest='learning_rate',
+        help='Learning rate',
+        type=float,
+        default=0.01)
+    parser.add_argument(
+        '--pretrain_weights',
+        dest='pretrain_weights',
+        help='The path of pretrained weights',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--resume_checkpoint',
+        dest='resume_checkpoint',
+        help='The path of resume checkpoint',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--use_vdl',
+        dest='use_vdl',
+        help='Whether to use visualdl',
+        action='store_true')
+    parser.add_argument(
+        '--save_interval_epochs',
+        dest='save_interval_epochs',
+        help='The interval epochs for save a model snapshot',
+        type=int,
+        default=5)
+
+    return parser.parse_args()
+
+
+def train(args):
+    train_transforms = transforms.Compose([
+        transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(),
+        transforms.Normalize()
+    ])
+
+    eval_transforms = transforms.Compose(
+        [transforms.Resize(args.image_shape), transforms.Normalize()])
+
+    train_dataset = pdx.datasets.SegDataset(
+        data_dir=args.data_dir,
+        file_list=args.train_list,
+        transforms=train_transforms,
+        shuffle=True)
+    eval_dataset = pdx.datasets.SegDataset(
+        data_dir=args.data_dir,
+        file_list=args.val_list,
+        transforms=eval_transforms)
+
+    if args.model_type == 'HumanSegMobile':
+        model = pdx.seg.HRNet(
+            num_classes=args.num_classes, width='18_small_v1')
+    elif args.model_type == 'HumanSegServer':
+        model = pdx.seg.DeepLabv3p(
+            num_classes=args.num_classes, backbone='Xception65')
+    else:
+        raise ValueError(
+            "--model_type: {} is set wrong, it should be one of ('HumanSegMobile', "
+            "'HumanSegServer')".format(args.model_type))
+    model.train(
+        num_epochs=args.num_epochs,
+        train_dataset=train_dataset,
+        train_batch_size=args.batch_size,
+        eval_dataset=eval_dataset,
+        save_interval_epochs=args.save_interval_epochs,
+        learning_rate=args.learning_rate,
+        pretrain_weights=args.pretrain_weights,
+        resume_checkpoint=args.resume_checkpoint,
+        save_dir=args.save_dir,
+        use_vdl=args.use_vdl)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    train(args)

+ 15 - 0
Applications/HumanSeg/utils/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import humanseg_postprocess

+ 124 - 0
Applications/HumanSeg/utils/humanseg_postprocess.py

@@ -0,0 +1,124 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
+    """计算光流跟踪匹配点和光流图
+    输入参数:
+        pre_gray: 上一帧灰度图
+        cur_gray: 当前帧灰度图
+        prev_cfd: 上一帧光流图
+        dl_weights: 融合权重图
+        disflow: 光流数据结构
+    返回值:
+        is_track: 光流点跟踪二值图,即是否具有光流点匹配
+        track_cfd: 光流跟踪图
+    """
+    check_thres = 8
+    h, w = pre_gray.shape[:2]
+    track_cfd = np.zeros_like(prev_cfd)
+    is_track = np.zeros_like(pre_gray)
+    flow_fw = disflow.calc(pre_gray, cur_gray, None)
+    flow_bw = disflow.calc(cur_gray, pre_gray, None)
+    flow_fw = np.round(flow_fw).astype(int)
+    flow_bw = np.round(flow_bw).astype(int)
+    y_list = np.array(range(h))
+    x_list = np.array(range(w))
+    yv, xv = np.meshgrid(y_list, x_list)
+    yv, xv = yv.T, xv.T
+    cur_x = xv + flow_fw[:, :, 0]
+    cur_y = yv + flow_fw[:, :, 1]
+
+    # Do not track points that flow outside the frame
+    not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
+    flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
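+    # Forward-backward consistency check: drop points whose round-trip flow error exceeds check_thres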
+    not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
+                  np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])
+                  ) >= check_thres
+    track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
+
+    is_track[cur_y[~not_track], cur_x[~not_track]] = 1
+
+    not_flow = np.all(np.abs(flow_fw) == 0,
+                      axis=-1) * np.all(np.abs(flow_bw) == 0, axis=-1)
+    dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
+    return track_cfd, is_track, dl_weights
+
+
+def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
+    """光流追踪图和人像分割结构融合
+    输入参数:
+        track_cfd: 光流追踪图
+        dl_cfd: 当前帧分割结果
+        dl_weights: 融合权重图
+        is_track: 光流点匹配二值图
+    返回
+        cur_cfd: 光流跟踪图和人像分割结果融合图
+    """
+    fusion_cfd = dl_cfd.copy()
+    is_track = is_track.astype(bool)
+    fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
+        1 - dl_weights[is_track]) * track_cfd[is_track]
+    # Regions where the segmentation is confident (score > 0.9 or < 0.1)
+    index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
+    index_less01 = (dl_weights < 0.1) * index_certain
+    fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
+        index_less01]
+    index_larger09 = (dl_weights >= 0.1) * index_certain
+    fusion_cfd[index_larger09] = 0.4 * dl_cfd[
+        index_larger09] + 0.6 * track_cfd[index_larger09]
+    return fusion_cfd
+
+
+def threshold_mask(img, thresh_bg, thresh_fg):
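+    # Linearly remap img/255 so thresh_bg maps to 0 and thresh_fg to 1, then clip to [0, 1]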
+    dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
+    dst[np.where(dst > 1)] = 1
+    dst[np.where(dst < 0)] = 0
+    return dst.astype(np.float32)
+
+
+def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
+    """光流优化
+    Args:
+        cur_gray : 当前帧灰度图
+        pre_gray : 前一帧灰度图
+        pre_cfd  :前一帧融合结果
+        scoremap : 当前帧分割结果
+        difflow  : 光流
+        is_init : 是否第一帧
+    Returns:
+        fusion_cfd : 光流追踪图和预测结果融合图
+    """
+    h, w = scoremap.shape
+    cur_cfd = scoremap.copy()
+
+    if is_init:
+        if h <= 64 or w <= 64:
+            disflow.setFinestScale(1)
+        elif h <= 160 or w <= 160:
+            disflow.setFinestScale(2)
+        else:
+            disflow.setFinestScale(3)
+        fusion_cfd = cur_cfd
+    else:
+        weights = np.ones((h, w), np.float32) * 0.3
+        track_cfd, is_track, weights = human_seg_tracking(
+            prev_gray, cur_gray, pre_cfd, weights, disflow)
+        fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights,
+                                          is_track)
+
+    return fusion_cfd

+ 177 - 0
Applications/HumanSeg/video_infer.py

@@ -0,0 +1,177 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+import cv2
+import numpy as np
+
+from utils.humanseg_postprocess import postprocess, threshold_mask
+import paddlex as pdx
+import paddlex.utils.logging as logging
+from paddlex.seg import transforms
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='HumanSeg inference for video')
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='Model path for inference',
+        type=str)
+    parser.add_argument(
+        '--video_path',
+        dest='video_path',
+        help='Video path for inference, camera will be used if the path not existing',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output')
+    parser.add_argument(
+        "--image_shape",
+        dest="image_shape",
+        help="The image shape for net inputs.",
+        nargs=2,
+        default=[192, 192],
+        type=int)
+
+    return parser.parse_args()
+
+
+def predict(img, model, test_transforms):
+    model.arrange_transforms(transforms=test_transforms, mode='test')
+    img, im_info = test_transforms(img.astype('float32'))
+    img = np.expand_dims(img, axis=0)
+    result = model.exe.run(model.test_prog,
+                           feed={'image': img},
+                           fetch_list=list(model.test_outputs.values()))
+    score_map = result[1]
+    score_map = np.squeeze(score_map, axis=0)
+    score_map = np.transpose(score_map, (1, 2, 0))
+    return score_map, im_info
+
+
+def recover(img, im_info):
+    for info in im_info[::-1]:
+        if info[0] == 'resize':
+            w, h = info[1][1], info[1][0]
+            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
+        elif info[0] == 'padding':
+            w, h = info[1][1], info[1][0]
+            img = img[0:h, 0:w, :]
+    return img
+
+
+def video_infer(args):
+    resize_h = args.image_shape[1]
+    resize_w = args.image_shape[0]
+
+    test_transforms = transforms.Compose(
+        [transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
+    model = pdx.load_model(args.model_dir)
+    if not args.video_path:
+        cap = cv2.VideoCapture(0)
+    else:
+        cap = cv2.VideoCapture(args.video_path)
+    if not cap.isOpened():
+        raise IOError("Error opening video stream or file, "
+                      "--video_path whether existing: {}"
+                      " or camera whether working".format(args.video_path))
+        return
+
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
+    prev_gray = np.zeros((resize_h, resize_w), np.uint8)
+    prev_cfd = np.zeros((resize_h, resize_w), np.float32)
+    is_init = True
+
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    if args.video_path:
+        logging.info("Please wait. It is computing......")
+        # 用于保存预测结果视频
+        if not osp.exists(args.save_dir):
+            os.makedirs(args.save_dir)
+        out = cv2.VideoWriter(
+            osp.join(args.save_dir, 'result.avi'),
+            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
+        # Start reading video frames
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if ret:
+                score_map, im_info = predict(frame, model, test_transforms)
+                cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                score_map = 255 * score_map[:, :, 1]
+                optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+                        disflow, is_init)
+                prev_gray = cur_gray.copy()
+                prev_cfd = optflow_map.copy()
+                is_init = False
+                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                optflow_map = threshold_mask(
+                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                img_matting = np.repeat(
+                    optflow_map[:, :, np.newaxis], 3, axis=2)
+                img_matting = recover(img_matting, im_info)
+                bg_im = np.ones_like(img_matting) * 255
+                comb = (img_matting * frame +
+                        (1 - img_matting) * bg_im).astype(np.uint8)
+                out.write(comb)
+            else:
+                break
+        cap.release()
+        out.release()
+
+    else:
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if ret:
+                score_map, im_info = predict(frame, model, test_transforms)
+                cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
+                score_map = 255 * score_map[:, :, 1]
+                optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
+                                          disflow, is_init)
+                prev_gray = cur_gray.copy()
+                prev_cfd = optflow_map.copy()
+                is_init = False
+                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
+                optflow_map = threshold_mask(
+                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
+                img_matting = np.repeat(
+                    optflow_map[:, :, np.newaxis], 3, axis=2)
+                img_matting = recover(img_matting, im_info)
+                bg_im = np.ones_like(img_matting) * 255
+                comb = (img_matting * frame +
+                        (1 - img_matting) * bg_im).astype(np.uint8)
+                cv2.imshow('HumanSegmentation', comb)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+            else:
+                break
+        cap.release()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    video_infer(args)

+ 10 - 9
paddlex/cv/datasets/seg_dataset.py

@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ class SegDataset(Dataset):
     Args:
         data_dir (str): Root directory of the dataset.
         file_list (str): Path to the file listing dataset images and their annotation files (each line holds paths relative to data_dir).
-        label_list (str): Path to the file describing the classes in the dataset.
+        label_list (str): Path to the file describing the classes in the dataset. Defaults to None.
         transforms (list): Preprocessing/augmentation operators applied to each sample.
         num_workers (int): Number of threads or processes used for sample preprocessing. Defaults to 4.
         buffer_size (int): Buffer length (in samples) of the preprocessing queue. Defaults to 100.
@@ -40,7 +40,7 @@ class SegDataset(Dataset):
     def __init__(self,
                  data_dir,
                  file_list,
-                 label_list,
+                 label_list=None,
                  transforms=None,
                  num_workers='auto',
                  buffer_size=100,
@@ -56,10 +56,11 @@ class SegDataset(Dataset):
         self.labels = list()
         self._epoch = 0
 
-        with open(label_list, encoding=get_encoding(label_list)) as f:
-            for line in f:
-                item = line.strip()
-                self.labels.append(item)
+        if label_list is not None:
+            with open(label_list, encoding=get_encoding(label_list)) as f:
+                for line in f:
+                    item = line.strip()
+                    self.labels.append(item)
 
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
@@ -69,8 +70,8 @@ class SegDataset(Dataset):
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
                 if not osp.exists(full_path_im):
-                    raise IOError(
-                        'The image file {} is not exist!'.format(full_path_im))
+                    raise IOError('The image file {} does not exist!'.format(
+                        full_path_im))
                 if not osp.exists(full_path_label):
                     raise IOError('The label file {} does not exist!'.format(
                         full_path_label))