Browse Source

Merge pull request #264 from mamingjie-China/develop

add dataset split tools and fix some docs
Jason 5 years ago
parent
commit
91b57d5ee7

+ 7 - 7
README.md

@@ -30,7 +30,7 @@
 
 
   **前置依赖**
   **前置依赖**
 > - paddlepaddle >= 1.8.0
 > - paddlepaddle >= 1.8.0
-> - python >= 3.5
+> - python >= 3.6
 > - cython
 > - cython
 > - pycocotools
 > - pycocotools
 
 
@@ -48,13 +48,13 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
 
 
 - 前往[PaddleX GUI使用教程](./docs/gui/how_to_use.md)了解PaddleX GUI使用详情。
 - 前往[PaddleX GUI使用教程](./docs/gui/how_to_use.md)了解PaddleX GUI使用详情。
 
 
-- https://aistudio.baidu.com/aistudio/projectdetail/440197)
+- https://aistudio.baidu.com/aistudio/projectdetail/440197
 
 
   
   
 
 
 ## 产品模块说明
 ## 产品模块说明
 
 
-- **数据准备**:兼容ImageNet、VOC、COCO等常用数据协议, 同时与Labelme、精灵标注助手、[EasyData智能数据服务平台](https://ai.baidu.com/easydata/)等无缝衔接,全方位助力开发者更快完成数据准备工作。
+- **数据准备**:兼容ImageNet、VOC、COCO等常用数据协议同时与Labelme、精灵标注助手、[EasyData智能数据服务平台](https://ai.baidu.com/easydata/)等无缝衔接,全方位助力开发者更快完成数据准备工作。
 
 
 - **数据预处理及增强**:提供极简的图像预处理和增强方法--Transforms,适配imgaug图像增强库,支持**上百种数据增强策略**,是开发者快速缓解小样本数据训练的问题。
 - **数据预处理及增强**:提供极简的图像预处理和增强方法--Transforms,适配imgaug图像增强库,支持**上百种数据增强策略**,是开发者快速缓解小样本数据训练的问题。
 
 
@@ -93,7 +93,7 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
   * [工业表计读数](https://paddlex.readthedocs.io/zh_CN/develop/examples/meter_reader.html)
   * [工业表计读数](https://paddlex.readthedocs.io/zh_CN/develop/examples/meter_reader.html)
 
 
 * 工业质检:
 * 工业质检:
-  * 电池隔膜缺陷检测(Comming Soon)
+  * 电池隔膜缺陷检测(Coming Soon)
 
 
 * [人像分割](https://paddlex.readthedocs.io/zh_CN/develop/examples/human_segmentation.html)
 * [人像分割](https://paddlex.readthedocs.io/zh_CN/develop/examples/human_segmentation.html)
 
 
@@ -105,8 +105,8 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
 
 
 ## 交流与反馈
 ## 交流与反馈
 
 
-- 项目官网: https://www.paddlepaddle.org.cn/paddle/paddlex
-- PaddleX用户交流群: 1045148026 (手机QQ扫描如下二维码快速加入)  
+- 项目官网https://www.paddlepaddle.org.cn/paddle/paddlex
+- PaddleX用户交流群1045148026 (手机QQ扫描如下二维码快速加入)  
   ![](./docs/gui/images/QR.jpg)
   ![](./docs/gui/images/QR.jpg)
 
 
 
 
@@ -124,4 +124,4 @@ pip install paddlex -i https://mirror.baidu.com/pypi/simple
 
 
 ## 贡献代码
 ## 贡献代码
 
 
-我们非常欢迎您为PaddleX贡献代码或者提供使用建议。如果您可以修复某个issue或者增加一个新功能,欢迎给我们提交Pull Requests.
+我们非常欢迎您为PaddleX贡献代码或者提供使用建议。如果您可以修复某个issue或者增加一个新功能,欢迎给我们提交Pull Requests

+ 1 - 1
docs/apis/deploy.md

@@ -40,7 +40,7 @@ predict(image, topk=1)
 > **参数**
 > **参数**
 >
 >
 > > * **image** (str|np.ndarray): 待预测的图片路径或numpy数组(HWC排列,BGR格式)。
 > > * **image** (str|np.ndarray): 待预测的图片路径或numpy数组(HWC排列,BGR格式)。
-> > * **topk** (int): 图像分类时使用的参数,表示预测前topk个可能的分类
+> > * **topk** (int): 图像分类时使用的参数,表示预测前topk个可能的分类
 
 
 ### batch_predict 接口
 ### batch_predict 接口
 ```
 ```

+ 1 - 1
docs/apis/models/detection.md

@@ -21,7 +21,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > > - **nms_score_threshold** (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
 > > - **nms_score_threshold** (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
 > > - **nms_topk** (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
 > > - **nms_topk** (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
 > > - **nms_keep_topk** (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
 > > - **nms_keep_topk** (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
-> > - **nms_iou_threshold** (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
+> > - **nms_iou_threshold** (float): 进行NMS时,用于剔除检测框IoU的阈值。默认为0.45。
 > > - **label_smooth** (bool): 是否使用label smooth。默认值为False。
 > > - **label_smooth** (bool): 是否使用label smooth。默认值为False。
 > > - **train_random_shapes** (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
 > > - **train_random_shapes** (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
 
 

+ 1 - 1
docs/apis/models/instance_segmentation.md

@@ -101,4 +101,4 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2)
 >
 >
 > **返回值**
 > **返回值**
 >
 >
-> > - **list**: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个元素均为一个dict,key'bbox', 'mask', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、Mask信息,类别、类别id、置信度。其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。Mask信息为原图大小的二值图,1表示像素点属于预测类别,0表示像素点是背景。
+> > - **list**: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个元素均为一个dict,包含关键字:'bbox', 'mask', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、Mask信息,类别、类别id、置信度。其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。Mask信息为原图大小的二值图,1表示像素点属于预测类别,0表示像素点是背景。

+ 1 - 1
docs/apis/models/semantic_segmentation.md

@@ -69,7 +69,7 @@ evaluate(self, eval_dataset, batch_size=1, epoch_id=None, return_details=False):
 > **返回值**
 > **返回值**
 > >
 > >
 > > - **dict**: 当`return_details`为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、
 > > - **dict**: 当`return_details`为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、
-> >   'category_acc'和'kappa',分别表示平均iou、各类别iou、平均准确率、各类别准确率和kappa系数。
+> >   'category_acc'和'kappa',分别表示平均IoU、各类别IoU、平均准确率、各类别准确率和kappa系数。
 > > - **tuple** (metrics, eval_details):当`return_details`为True时,增加返回dict (eval_details),
 > > - **tuple** (metrics, eval_details):当`return_details`为True时,增加返回dict (eval_details),
 > >   包含关键字:'confusion_matrix',表示评估的混淆矩阵。
 > >   包含关键字:'confusion_matrix',表示评估的混淆矩阵。
 
 

+ 6 - 6
docs/apis/slim.md

@@ -26,16 +26,16 @@ paddlex.slim.cal_params_sensitivities(model, save_file, eval_dataset, batch_size
 ```
 ```
 paddlex.slim.export_quant_model(model, test_dataset, batch_size=2, batch_num=10, save_dir='./quant_model', cache_dir='./temp')
 paddlex.slim.export_quant_model(model, test_dataset, batch_size=2, batch_num=10, save_dir='./quant_model', cache_dir='./temp')
 ```
 ```
-导出量化模型,该接口实现了Post Quantization量化方式,需要传入测试数据集,并设定`batch_size`和`batch_num`。量化过程中会以数量为`batch_size` X `batch_num`的样本数据的计算结果为统计信息完成模型的量化。
+导出量化模型,该接口实现了Post Quantization量化方式,需要传入测试数据集,并设定`batch_size`和`batch_num`。量化过程中会以数量为`batch_size` * `batch_num`的样本数据的计算结果为统计信息完成模型的量化。
 
 
 **参数**
 **参数**
 
 
 * **model**(paddlex.cls.models/paddlex.det.models/paddlex.seg.models): paddlex加载的模型。
 * **model**(paddlex.cls.models/paddlex.det.models/paddlex.seg.models): paddlex加载的模型。
-* **test_dataset**(paddlex.dataset): 测试数据集
-* **batch_size**(int): 进行前向计算时的批数据大小
-* **batch_num**(int): 进行向前计算时批数据数量
-* **save_dir**(str): 量化后模型的保存目录
-* **cache_dir**(str): 量化过程中的统计数据临时存储目录
+* **test_dataset**(paddlex.dataset): 测试数据集
+* **batch_size**(int): 进行前向计算时的批数据大小
+* **batch_num**(int): 进行向前计算时批数据数量
+* **save_dir**(str): 量化后模型的保存目录
+* **cache_dir**(str): 量化过程中的统计数据临时存储目录
 
 
 
 
 **使用示例**
 **使用示例**

+ 1 - 1
docs/apis/transforms/augment.md

@@ -14,7 +14,7 @@ PaddleX对于图像分类、目标检测、实例分割和语义分割内置了
 
 
 ## imgaug增强库的支持
 ## imgaug增强库的支持
 
 
-PaddleX目前已适配imgaug图像增强库,用户可以直接在PaddleX构造`transforms`时,调用imgaug的方法, 如下示例
+PaddleX目前已适配imgaug图像增强库,用户可以直接在PaddleX构造`transforms`时,调用imgaug的方法,如下示例,
 ```
 ```
 import paddlex as pdx
 import paddlex as pdx
 from paddlex.cls import transforms
 from paddlex.cls import transforms

+ 5 - 5
docs/apis/transforms/seg_transforms.md

@@ -16,7 +16,7 @@ paddlex.seg.transforms.Compose(transforms)
 ```python
 ```python
 paddlex.seg.transforms.RandomHorizontalFlip(prob=0.5)
 paddlex.seg.transforms.RandomHorizontalFlip(prob=0.5)
 ```
 ```
-以一定的概率对图像进行水平翻转,模型训练时的数据增强操作。
+以一定的概率对图像进行水平翻转模型训练时的数据增强操作。
 ### 参数
 ### 参数
 * **prob** (float): 随机水平翻转的概率。默认值为0.5。
 * **prob** (float): 随机水平翻转的概率。默认值为0.5。
 
 
@@ -25,7 +25,7 @@ paddlex.seg.transforms.RandomHorizontalFlip(prob=0.5)
 ```python
 ```python
 paddlex.seg.transforms.RandomVerticalFlip(prob=0.1)
 paddlex.seg.transforms.RandomVerticalFlip(prob=0.1)
 ```
 ```
-以一定的概率对图像进行垂直翻转,模型训练时的数据增强操作。
+以一定的概率对图像进行垂直翻转模型训练时的数据增强操作。
 ### 参数
 ### 参数
 * **prob**  (float): 随机垂直翻转的概率。默认值为0.1。
 * **prob**  (float): 随机垂直翻转的概率。默认值为0.1。
 
 
@@ -59,7 +59,7 @@ paddlex.seg.transforms.ResizeByLong(long_size)
 ```python
 ```python
 paddlex.seg.transforms.ResizeRangeScaling(min_value=400, max_value=600)
 paddlex.seg.transforms.ResizeRangeScaling(min_value=400, max_value=600)
 ```
 ```
-对图像长边随机resize到指定范围内,短边按比例进行缩放,模型训练时的数据增强操作。
+对图像长边随机resize到指定范围内,短边按比例进行缩放模型训练时的数据增强操作。
 ### 参数
 ### 参数
 * **min_value** (int): 图像长边resize后的最小值。默认值400。
 * **min_value** (int): 图像长边resize后的最小值。默认值400。
 * **max_value** (int): 图像长边resize后的最大值。默认值600。
 * **max_value** (int): 图像长边resize后的最大值。默认值600。
@@ -124,7 +124,7 @@ paddlex.seg.transforms.RandomBlur(prob=0.1)
 ```python
 ```python
 paddlex.seg.transforms.RandomRotate(rotate_range=15, im_padding_value=[127.5, 127.5, 127.5], label_padding_value=255)
 paddlex.seg.transforms.RandomRotate(rotate_range=15, im_padding_value=[127.5, 127.5, 127.5], label_padding_value=255)
 ```
 ```
-对图像进行随机旋转, 模型训练时的数据增强操作。
+对图像进行随机旋转模型训练时的数据增强操作。
 
 
 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
 并对旋转后的图像和标注图像进行相应的padding。
 并对旋转后的图像和标注图像进行相应的padding。
@@ -138,7 +138,7 @@ paddlex.seg.transforms.RandomRotate(rotate_range=15, im_padding_value=[127.5, 12
 ```python
 ```python
 paddlex.seg.transforms.RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33)
 paddlex.seg.transforms.RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33)
 ```
 ```
-裁剪并resize回原始尺寸的图像和标注图像,模型训练时的数据增强操作。
+裁剪并resize回原始尺寸的图像和标注图像模型训练时的数据增强操作。
 
 
 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
 ### 参数
 ### 参数

+ 1 - 1
docs/apis/visualize.md

@@ -131,7 +131,7 @@ paddlex.transforms.visualize(dataset,
 对数据预处理/增强中间结果进行可视化。
 对数据预处理/增强中间结果进行可视化。
 可使用VisualDL查看中间结果:
 可使用VisualDL查看中间结果:
 1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
 1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
-2. 浏览器打开 https://0.0.0.0:8001即可,
+2. 浏览器打开 https://0.0.0.0:8001 即可,
     其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
     其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
 
 
 ### 参数
 ### 参数

+ 1 - 1
docs/deploy/paddlelite/slim/prune.md

@@ -49,7 +49,7 @@ PaddleX提供了两种方式:
 ### 语义分割
 ### 语义分割
 实验背景:使用UNet模型,数据集为视盘分割示例数据,剪裁训练代码见[tutorials/compress/segmentation](https://github.com/PaddlePaddle/PaddleX/tree/develop/tutorials/compress/segmentation)
 实验背景:使用UNet模型,数据集为视盘分割示例数据,剪裁训练代码见[tutorials/compress/segmentation](https://github.com/PaddlePaddle/PaddleX/tree/develop/tutorials/compress/segmentation)
 
 
-| 模型 | 剪裁情况 | 模型大小 | mIOU(%) |GPU预测速度 | CPU预测速度 |
+| 模型 | 剪裁情况 | 模型大小 | mIoU(%) |GPU预测速度 | CPU预测速度 |
 | :-----| :--------| :-------- | :---------- |:---------- | :---------|
 | :-----| :--------| :-------- | :---------- |:---------- | :---------|
 |UNet | 无剪裁(原模型)| 77M | 91.22 |33.28ms |9523.55ms |
 |UNet | 无剪裁(原模型)| 77M | 91.22 |33.28ms |9523.55ms |
 |UNet | 方案一(eval_metric_loss=0.10) |26M | 90.37 |21.04ms |3936.20ms |
 |UNet | 方案一(eval_metric_loss=0.10) |26M | 90.37 |21.04ms |3936.20ms |

+ 2 - 2
docs/examples/solutions.md

@@ -74,9 +74,9 @@ PaddleX目前提供了实例分割MaskRCNN模型,支持5种不同的backbone
 > 表中GPU预测速度是使用PaddlePaddle Python预测接口测试得到(测试GPU型号为Nvidia Tesla P40)。
 > 表中GPU预测速度是使用PaddlePaddle Python预测接口测试得到(测试GPU型号为Nvidia Tesla P40)。
 > 表中CPU预测速度 (测试CPU型号为)。
 > 表中CPU预测速度 (测试CPU型号为)。
 > 表中骁龙855预测速度是使用处理器为骁龙855的手机测试得到。
 > 表中骁龙855预测速度是使用处理器为骁龙855的手机测试得到。
-> 测速时模型的输入大小为1024 x 2048,mIOU为Cityscapes数据集上评估所得。
+> 测速时模型的输入大小为1024 x 2048,mIoU为Cityscapes数据集上评估所得。
 
 
-| 模型 | 模型特点 | 存储体积 | GPU预测速度 | CPU(x86)预测速度(毫秒) | 骁龙855(ARM)预测速度 (毫秒)| mIOU |
+| 模型 | 模型特点 | 存储体积 | GPU预测速度 | CPU(x86)预测速度(毫秒) | 骁龙855(ARM)预测速度 (毫秒)| mIoU |
 | :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
 | :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
 | DeepLabv3p-MobileNetV2_x1.0 | 轻量级模型,适用于移动端场景| - | - | - | 69.8% |
 | DeepLabv3p-MobileNetV2_x1.0 | 轻量级模型,适用于移动端场景| - | - | - | 69.8% |
 | HRNet_W18_Small_v1 | 轻量高速,适用于移动端场景 | - | - | - | - |
 | HRNet_W18_Small_v1 | 轻量高速,适用于移动端场景 | - | - | - | - |

+ 2 - 2
docs/train/semantic_segmentation.md

@@ -4,11 +4,11 @@
 
 
 PaddleX目前提供了DeepLabv3p、UNet、HRNet和FastSCNN四种语义分割结构,多种backbone模型,可满足开发者不同场景和性能的需求。
 PaddleX目前提供了DeepLabv3p、UNet、HRNet和FastSCNN四种语义分割结构,多种backbone模型,可满足开发者不同场景和性能的需求。
 
 
-- **mIOU**: 模型在CityScape数据集上的测试精度
+- **mIoU**: 模型在CityScape数据集上的测试精度
 - **预测速度**:单张图片的预测用时(不包括预处理和后处理)
 - **预测速度**:单张图片的预测用时(不包括预处理和后处理)
 - "-"表示指标暂未更新
 - "-"表示指标暂未更新
 
 
-| 模型(点击获取代码)               | mIOU | 模型大小 | GPU预测速度 | Arm预测速度 | 备注 |
+| 模型(点击获取代码)               | mIoU | 模型大小 | GPU预测速度 | Arm预测速度 | 备注 |
 | :----------------  | :------- | :------- | :---------  | :---------  | :-----    |
 | :----------------  | :------- | :------- | :---------  | :---------  | :-----    |
 | [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) |  -  |  2.9MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
 | [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) |  -  |  2.9MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
 | [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) |  69.8%  |  11MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
 | [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) |  69.8%  |  11MB  |  -   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |

+ 1 - 1
paddlex/__init__.py

@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # you may not use this file except in compliance with the License.

+ 55 - 4
paddlex/command.py

@@ -15,6 +15,7 @@
 from six import text_type as _text_type
 from six import text_type as _text_type
 import argparse
 import argparse
 import sys
 import sys
+import os.path as osp
 import paddlex.utils.logging as logging
 import paddlex.utils.logging as logging
 
 
 
 
@@ -85,6 +86,33 @@ def arg_parser():
         "-fs",
         "-fs",
         default=None,
         default=None,
         help="export inference model with fixed input shape:[w,h]")
         help="export inference model with fixed input shape:[w,h]")
+    parser.add_argument(
+        "--split_dataset",
+        "-sd",
+        action="store_true",
+        default=False,
+        help="split dataset with the split value")
+    parser.add_argument(
+        "--form",
+        "-f",
+        default=None,
+        help="define dataset format(ImageNet/COCO/VOC/Seg)")
+    parser.add_argument(
+        "--dataset_dir",
+        "-dd",
+        type=_text_type,
+        default=None,
+        help="define the path of dataset to be splited")
+    parser.add_argument(
+        "--val_value",
+        "-vv",
+        default=None,
+        help="define the value of validation dataset(E.g 0.2)")
+    parser.add_argument(
+        "--test_value",
+        "-tv",
+        default=None,
+        help="define the value of test dataset(E.g 0.1)")
     return parser
     return parser
 
 
 
 
@@ -135,7 +163,7 @@ def main():
                 "paddlex --export_inference --model_dir model_path --save_dir infer_model"
                 "paddlex --export_inference --model_dir model_path --save_dir infer_model"
             )
             )
         pdx.convertor.export_onnx_model(model, args.save_dir)
         pdx.convertor.export_onnx_model(model, args.save_dir)
-        
+
     if args.data_conversion:
     if args.data_conversion:
         assert args.source is not None, "--source should be defined while converting dataset"
         assert args.source is not None, "--source should be defined while converting dataset"
         assert args.to is not None, "--to should be defined to confirm the taregt dataset format"
         assert args.to is not None, "--to should be defined to confirm the taregt dataset format"
@@ -150,9 +178,32 @@ def main():
             logging.error(
             logging.error(
                 "The jingling dataset can not convert to the PascalVOC dataset.",
                 "The jingling dataset can not convert to the PascalVOC dataset.",
                 exit=False)
                 exit=False)
-        pdx.tools.convert.dataset_conversion(args.source, args.to, 
-                                             args.pics, args.annotations, args.save_dir )
-        
+        pdx.tools.convert.dataset_conversion(args.source, args.to, args.pics,
+                                             args.annotations, args.save_dir)
+
+    if args.split_dataset:
+        assert args.dataset_dir is not None, "--dataset_dir should be defined while spliting dataset"
+        assert args.form is not None, "--form should be defined while spliting dataset"
+        assert args.val_value is not None, "--val_value should be defined while spliting dataset"
+
+        dataset_dir = args.dataset_dir
+        dataset_form = args.form.lower()
+        val_value = float(args.val_value)
+        test_value = float(args.test_value
+                           if args.test_value is not None else 0)
+        save_dir = dataset_dir
+
+        if not dataset_form in ["coco", "imagenet", "voc", "seg"]:
+            logging.error(
+                "The dataset form is not correct defined.(support COCO/ImageNet/VOC/Seg)"
+            )
+        if not osp.exists(dataset_dir):
+            logging.error("The path of dataset to be splited doesn't exist.")
+        if val_value <= 0 or val_value >= 1 or test_value < 0 or test_value >= 1 or val_value + test_value >= 1:
+            logging.error("The value of split is not correct.")
+
+        pdx.tools.split.dataset_split(dataset_dir, dataset_form, val_value,
+                                      test_value, save_dir)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 1 - 1
paddlex/cv/models/yolo_v3.py

@@ -45,7 +45,7 @@ class YOLOv3(BaseAPI):
         nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
         nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
         nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
         nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
         nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
         nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
-        nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
+        nms_iou_threshold (float): 进行NMS时,用于剔除检测框IoU的阈值。默认为0.45。
         label_smooth (bool): 是否使用label smooth。默认值为False。
         label_smooth (bool): 是否使用label smooth。默认值为False。
         train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
         train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
     """
     """

+ 2 - 1
paddlex/tools/__init__.py

@@ -14,4 +14,5 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from .convert import *
+from .convert import *
+from .split import *

+ 0 - 0
paddlex/tools/dataset_split/__init__.py


+ 64 - 0
paddlex/tools/dataset_split/coco_split.py

@@ -0,0 +1,64 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import json
+from pycocotools.coco import COCO
+from .utils import MyEncoder
+import paddlex.utils.logging as logging
+
+
+def split_coco_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "annotations.json")):
+        logging.error("\'annotations.json\' is not found in {}!".format(
+            dataset_dir))
+
+    annotation_file = osp.join(dataset_dir, "annotations.json")
+    coco = COCO(annotation_file)
+    img_ids = coco.getImgIds()
+    cat_ids = coco.getCatIds()
+    anno_ids = coco.getAnnIds()
+
+    val_num = int(len(img_ids) * val_percent)
+    test_num = int(len(img_ids) * test_percent)
+    train_num = len(img_ids) - val_num - test_num
+
+    random.shuffle(img_ids)
+    train_files_ids = img_ids[:train_num]
+    val_files_ids = img_ids[train_num:train_num + val_num]
+    test_files_ids = img_ids[train_num + val_num:]
+
+    for img_id_list in [train_files_ids, val_files_ids, test_files_ids]:
+        img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
+        imgs = coco.loadImgs(img_id_list)
+        instances = coco.loadAnns(img_anno_ids)
+        categories = coco.loadCats(cat_ids)
+        img_dict = {
+            "annotations": instances,
+            "images": imgs,
+            "categories": categories
+        }
+
+        if img_id_list == train_files_ids:
+            json_file = open(osp.join(save_dir, 'train.json'), 'w+')
+            json.dump(img_dict, json_file, cls=MyEncoder)
+        elif img_id_list == val_files_ids:
+            json_file = open(osp.join(save_dir, 'val.json'), 'w+')
+            json.dump(img_dict, json_file, cls=MyEncoder)
+        elif img_id_list == test_files_ids and len(test_files_ids):
+            json_file = open(osp.join(save_dir, 'test.json'), 'w+')
+            json.dump(img_dict, json_file, cls=MyEncoder)
+
+    return train_num, val_num, test_num

+ 75 - 0
paddlex/tools/dataset_split/imagenet_split.py

@@ -0,0 +1,75 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic
+import paddlex.utils.logging as logging
+
+
+def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    all_files = list_files(dataset_dir)
+    label_list = list()
+    train_image_anno_list = list()
+    val_image_anno_list = list()
+    test_image_anno_list = list()
+    for file in all_files:
+        if not is_pic(file):
+            continue
+        label, image_name = osp.split(file)
+        if label not in label_list:
+            label_list.append(label)
+    label_list = sorted(label_list)
+
+    for i in range(len(label_list)):
+        image_list = list_files(osp.join(dataset_dir, label_list[i]))
+        image_anno_list = list()
+        for img in image_list:
+            image_anno_list.append([osp.join(label_list[i], img), i])
+        random.shuffle(image_anno_list)
+        image_num = len(image_anno_list)
+        val_num = int(image_num * val_percent)
+        test_num = int(image_num * test_percent)
+        train_num = image_num - val_num - test_num
+
+        train_image_anno_list += image_anno_list[:train_num]
+        val_image_anno_list += image_anno_list[train_num:train_num + val_num]
+        test_image_anno_list += image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file, label = x
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return len(train_image_anno_list), len(val_image_anno_list), len(
+        test_image_anno_list)

+ 93 - 0
paddlex/tools/dataset_split/seg_split.py

@@ -0,0 +1,93 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic, replace_ext, read_seg_ann
+import paddlex.utils.logging as logging
+
+
+def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        logging.error("\'JPEGImages\' is not found in {}!".format(dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        logging.error("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "png")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+        else:
+            anno_name = replace_ext(image_file, "PNG")
+            if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+                image_anno_list.append([image_file, anno_name])
+
+    if not osp.exists(osp.join(dataset_dir, "labels.txt")):
+        for image_anno in image_anno_list:
+            labels = read_seg_ann(
+                osp.join(dataset_dir, "Annotations", anno_name))
+            for i in labels:
+                if i not in label_list:
+                    label_list.append(i)
+        # 如果类标签的最大值大于类别数,添加对应缺失的标签
+        if len(label_list) != max(label_list) + 1:
+            label_list = [i for i in range(max(label_list) + 1)]
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    if len(label_list):
+        with open(
+                osp.join(save_dir, 'labels.txt'), mode='w',
+                encoding='utf-8') as f:
+            for l in sorted(label_list):
+                f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num

+ 102 - 0
paddlex/tools/dataset_split/utils.py

@@ -0,0 +1,102 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from PIL import Image
+import numpy as np
+import json
+
+
+class MyEncoder(json.JSONEncoder):
+    # 调整json文件存储形式
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(MyEncoder, self).default(obj)
+
+
+def list_files(dirname):
+    """ 列出目录下所有文件(包括所属的一级子目录下文件)
+
+    Args:
+        dirname: 目录路径
+    """
+
+    def filter_file(f):
+        if f.startswith('.'):
+            return True
+        return False
+
+    all_files = list()
+    dirs = list()
+    for f in os.listdir(dirname):
+        if filter_file(f):
+            continue
+        if osp.isdir(osp.join(dirname, f)):
+            dirs.append(f)
+        else:
+            all_files.append(f)
+    for d in dirs:
+        for f in os.listdir(osp.join(dirname, d)):
+            if filter_file(f):
+                continue
+            if osp.isdir(osp.join(dirname, d, f)):
+                continue
+            all_files.append(osp.join(d, f))
+    return all_files
+
+
+def is_pic(filename):
+    """ 判断文件是否为图片格式
+
+    Args:
+        filename: 文件路径
+    """
+    suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
+    suffix = filename.strip().split('.')[-1]
+    if suffix not in suffixes:
+        return False
+    return True
+
+
+def replace_ext(filename, new_ext):
+    """ 替换文件后缀
+
+    Args:
+        filename: 文件路径
+        new_ext: 需要替换的新的后缀
+    """
+    items = filename.split(".")
+    items[-1] = new_ext
+    new_filename = ".".join(items)
+    return new_filename
+
+
+def read_seg_ann(pngfile):
+    """ 解析语义分割的标注png图片
+
+    Args:
+        pngfile: 包含标注信息的png图片路径
+    """
+    grt = np.asarray(Image.open(pngfile))
+    labels = list(np.unique(grt))
+    if 255 in labels:
+        labels.remove(255)
+    return labels

+ 88 - 0
paddlex/tools/dataset_split/voc_split.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import xml.etree.ElementTree as ET
+from .utils import list_files, is_pic, replace_ext
+import paddlex.utils.logging as logging
+
+
+def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        logging.error("\'JPEGImages\' is not found in {}!".format(dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        logging.error("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "xml")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+            try:
+                tree = ET.parse(
+                    osp.join(dataset_dir, "Annotations", anno_name))
+            except:
+                raise Exception("文件{}不是一个良构的xml文件,请检查标注文件".format(
+                    osp.join(dataset_dir, "Annotations", anno_name)))
+            objs = tree.findall("object")
+            for i, obj in enumerate(objs):
+                cname = obj.find('name').text
+                if not cname in label_list:
+                    label_list.append(cname)
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num

+ 40 - 0
paddlex/tools/split.py

@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dataset_split.coco_split import split_coco_dataset
+from .dataset_split.voc_split import split_voc_dataset
+from .dataset_split.imagenet_split import split_imagenet_dataset
+from .dataset_split.seg_split import split_seg_dataset
+
+
+def dataset_split(dataset_dir, dataset_form, val_value, test_value, save_dir):
+    if dataset_form == "coco":
+        train_num, val_num, test_num = split_coco_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "voc":
+        train_num, val_num, test_num = split_voc_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "seg":
+        train_num, val_num, test_num = split_seg_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "imagenet":
+        train_num, val_num, test_num = split_imagenet_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    print("Dataset Split Done.")
+    print("Train samples: {}".format(train_num))
+    print("Eval samples: {}".format(val_num))
+    print("Test samples: {}".format(test_num))
+    print("Split file saved in {}".format(save_dir))