Browse Source

Merge pull request #22 from SunAhong1993/syf_det_trans

add easydata
Jason 5 years ago
parent
commit
4cc445a40c

+ 63 - 4
docs/apis/datasets.md

@@ -16,7 +16,7 @@ paddlex.datasets.ImageNet(data_dir, file_list, label_list, transforms=None, num_
 > * **transforms** (paddlex.cls.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.cls.transforms](./transforms/cls_transforms.md)。  
 > * **transforms** (paddlex.cls.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.cls.transforms](./transforms/cls_transforms.md)。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
-> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 
 
 ## VOCDetection类
 ## VOCDetection类
@@ -37,7 +37,7 @@ paddlex.datasets.VOCDetection(data_dir, file_list, label_list, transforms=None,
 > * **transforms** (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.det.transforms](./transforms/det_transforms.md)。  
 > * **transforms** (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.det.transforms](./transforms/det_transforms.md)。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
-> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 
 
 ## CocoDetection类
 ## CocoDetection类
@@ -57,7 +57,7 @@ paddlex.datasets.CocoDetection(data_dir, ann_file, transforms=None, num_workers=
 > * **transforms** (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.det.transforms](./transforms/det_transforms.md)。  
 > * **transforms** (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.det.transforms](./transforms/det_transforms.md)。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
-> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 
 
 ## SegDataset类
 ## SegDataset类
@@ -78,5 +78,64 @@ paddlex.datasets.SegDataset(data_dir, file_list, label_list, transforms=None, nu
 > * **transforms** (paddlex.seg.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.seg.transforms](./transforms/seg_transforms.md)。  
 > * **transforms** (paddlex.seg.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.seg.transforms](./transforms/seg_transforms.md)。  
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
 > * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
 > * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
-> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。 
+
+## EasyDataCls类
+```
+paddlex.datasets.SegDataset(data_dir, file_list, label_list, transforms=None, num_workers='auto', buffer_size=100, parallel_method='thread', shuffle=False)
+```
+读取EasyData图像分类数据集,并对样本进行相应的处理。EasyData图像分类任务数据集格式的介绍可查看文档:[数据集格式说明](../datasets.md)  
+
+
+### 参数
+
+> * **data_dir** (str): 数据集所在的目录路径。  
+> * **file_list** (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对`data_dir`的相对路径)。
+> * **label_list** (str): 描述数据集包含的类别信息文件路径。  
+> * **transforms** (paddlex.seg.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.cls.transforms](./transforms/cls_transforms.md)。  
+> * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
+> * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+
+## EasyDataDet类
+
+```
+paddlex.datasets.EasyDataDet(data_dir, file_list, label_list, transforms=None, num_workers=‘auto’, buffer_size=100, parallel_method='thread', shuffle=False)
+```
+
+读取EasyData目标检测格式数据集,并对样本进行相应的处理,该格式的数据集同样可以应用到实例分割模型的训练中。EasyData目标检测或实例分割任务数据集格式的介绍可查看文档:[数据集格式说明](../datasets.md)  
+
+
+### 参数
+
+> * **data_dir** (str): 数据集所在的目录路径。  
+> * **file_list** (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对`data_dir`的相对路径)。
+> * **label_list** (str): 描述数据集包含的类别信息文件路径。  
+> * **transforms** (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.det.transforms](./transforms/det_transforms.md)。  
+> * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
+> * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
 > * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。  
+
+
+## EasyDataSeg类
+
+```
+paddlex.datasets.EasyDataSeg(data_dir, file_list, label_list, transforms=None, num_workers='auto', buffer_size=100, parallel_method='thread', shuffle=False)
+```
+
+读取EasyData语分分割任务数据集,并对样本进行相应的处理。EasyData语义分割任务数据集格式的介绍可查看文档:[数据集格式说明](../datasets.md)  
+
+
+### 参数
+
+> * **data_dir** (str): 数据集所在的目录路径。  
+> * **file_list** (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对`data_dir`的相对路径)。
+> * **label_list** (str): 描述数据集包含的类别信息文件路径。  
+> * **transforms** (paddlex.seg.transforms): 数据集中每个样本的预处理/增强算子,详见[paddlex.seg.transforms](./transforms/seg_transforms.md)。  
+> * **num_workers** (int|str):数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
+> * **buffer_size** (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。  
+> * **parallel_method** (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。  
+> * **shuffle** (bool): 是否需要对数据集中样本打乱顺序。默认为False。 

+ 173 - 10
docs/datasets.md

@@ -41,8 +41,8 @@ labelA
 labelB
 labelB
 ...
 ...
 ```
 ```
-[点击这里](https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz),下载蔬菜分类分类数据集
-在PaddleX中,使用`paddlex.cv.datasets.ImageNet`([API说明](./apis/datasets.html#imagenet))加载分类数据集
+[点击这里](https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz),下载蔬菜分类分类数据集
+在PaddleX中,使用`paddlex.cv.datasets.ImageNet`([API说明](./apis/datasets.html#imagenet))加载分类数据集
 
 
 ## 目标检测VOC
 ## 目标检测VOC
 目标检测VOC数据集包含图像文件夹、标注信息文件夹、标签文件及图像列表文件。
 目标检测VOC数据集包含图像文件夹、标注信息文件夹、标签文件及图像列表文件。
@@ -81,8 +81,8 @@ labelA
 labelB
 labelB
 ...
 ...
 ```
 ```
-[点击这里](https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz),下载昆虫检测数据集
-在PaddleX中,使用`paddlex.cv.datasets.VOCDetection`([API说明](./apis/datasets.html#vocdetection))加载目标检测VOC数据集
+[点击这里](https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz),下载昆虫检测数据集
+在PaddleX中,使用`paddlex.cv.datasets.VOCDetection`([API说明](./apis/datasets.html#vocdetection))加载目标检测VOC数据集
 
 
 ## 目标检测和实例分割COCO
 ## 目标检测和实例分割COCO
 目标检测和实例分割COCO数据集包含图像文件夹及图像标注信息文件。
 目标检测和实例分割COCO数据集包含图像文件夹及图像标注信息文件。
@@ -135,7 +135,7 @@ labelB
   ]
   ]
 }
 }
 ```
 ```
-每个字段的含义如下所示:
+其中,每个字段的含义如下所示:
 
 
 | 域名 | 字段名 | 含义 | 数据类型 | 备注 |
 | 域名 | 字段名 | 含义 | 数据类型 | 备注 |
 |:-----|:--------|:------------|------|:-----|
 |:-----|:--------|:------------|------|:-----|
@@ -155,8 +155,8 @@ labelB
 | categories | supercategory | 类别父类的标签名 | str |  |
 | categories | supercategory | 类别父类的标签名 | str |  |
 
 
 
 
-[点击这里](https://bj.bcebos.com/paddlex/datasets/garbage_ins_det.tar.gz),下载垃圾实例分割数据集
-在PaddleX中,使用`paddlex.cv.datasets.COCODetection`([API说明](./apis/datasets.html#cocodetection))加载COCO格式数据集
+[点击这里](https://bj.bcebos.com/paddlex/datasets/garbage_ins_det.tar.gz),下载垃圾实例分割数据集
+在PaddleX中,使用`paddlex.cv.datasets.COCODetection`([API说明](./apis/datasets.html#cocodetection))加载COCO格式数据集
 
 
 ## 语义分割数据
 ## 语义分割数据
 语义分割数据集包含原图、标注图及相应的文件列表文件。
 语义分割数据集包含原图、标注图及相应的文件列表文件。
@@ -191,13 +191,176 @@ images/xxx2.png annotations/xxx2.png
 
 
 `labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
 `labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
 ```
 ```
+background
 labelA
 labelA
 labelB
 labelB
 ...
 ...
 ```
 ```
 
 
-标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增,
+标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增(一般第一个类别为`background`)
 例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
 例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
 
 
-[点击这里](https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz),下载视盘语义分割数据集
-在PaddleX中,使用`paddlex.cv.datasets.SegReader`([API说明](./apis/datasets.html#segreader))加载语义分割数据集
+[点击这里](https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz),下载视盘语义分割数据集。  
+在PaddleX中,使用`paddlex.cv.datasets.SegReader`([API说明](./apis/datasets.html#segreader))加载语义分割数据集。
+
+
+## 图像分类EasyDataCls
+
+图像分类EasyDataCls数据集包含存放图像和json文件的文件夹、标签文件及图像列表文件。
+参考数据文件结构如下:
+```
+./dataset/  # 数据集根目录
+|--easydata  # 存放图像和json文件的文件夹
+|  |--0001.jpg
+|  |--0001.json
+|  |--0002.jpg
+|  |--0002.json
+|  └--...
+|
+|--train_list.txt  # 训练文件列表文件
+|
+|--val_list.txt  # 验证文件列表文件
+|
+└--labels.txt  # 标签列表文件
+
+```
+其中,图像文件名应与json文件名一一对应。   
+
+每个json文件存储于`labels`相关的信息。如下所示:
+```
+{"labels": [{"name": "labelA"}]}
+```
+其中,`name`字段代表对应图像的类别。  
+
+`train_list.txt`和`val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为json文件相对于dataset的相对路径。如下所示:
+```
+easydata/0001.jpg easydata/0001.json
+easydata/0002.jpg easydata/0002.json
+...
+```
+
+`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
+```
+labelA
+labelB
+...
+```
+[点击这里](https://ai.baidu.com/easydata/),可以标注图像分类EasyDataCls数据集。  
+在PaddleX中,使用`paddlex.cv.datasets.EasyDataCls`([API说明](./apis/datasets.html#easydatacls))加载分类数据集。
+
+
+## 目标检测和实例分割EasyDataDet
+
+目标检测和实例分割EasyDataDet数据集包含存放图像和json文件的文件夹、标签文件及图像列表文件。
+参考数据文件结构如下:
+```
+./dataset/  # 数据集根目录ß
+|--easydata  # 存放图像和json文件的文件夹
+|  |--0001.jpg
+|  |--0001.json
+|  |--0002.jpg
+|  |--0002.json
+|  └--...
+|
+|--train_list.txt  # 训练文件列表文件
+|
+|--val_list.txt  # 验证文件列表文件
+|
+└--labels.txt  # 标签列表文件
+
+```
+其中,图像文件名应与json文件名一一对应。   
+
+每个json文件存储于`labels`相关的信息。如下所示:
+```
+"labels": [{"y1": 18, "x2": 883, "x1": 371, "y2": 404, "name": "labelA", 
+            "mask": "kVfc0`0Zg0<F7J7I5L5K4L4L4L3N3L3N3L3N2N3M2N2N2N2N2N2N1O2N2O1N2N1O2O1N101N1O2O1N101N10001N101N10001N10001O0O10001O000O100000001O0000000000000000000000O1000001O00000O101O000O101O0O101O0O2O0O101O0O2O0O2N2O0O2O0O2N2O1N1O2N2N2O1N2N2N2N2N2N2M3N3M2M4M2M4M3L4L4L4K6K5J7H9E\\iY1"}, 
+           {"y1": 314, "x2": 666, "x1": 227, "y2": 676, "name": "labelB",
+            "mask": "mdQ8g0Tg0:G8I6K5J5L4L4L4L4M2M4M2M4M2N2N2N3L3N2N2N2N2O1N1O2N2N2O1N1O2N2O0O2O1N1O2O0O2O0O2O001N100O2O000O2O000O2O00000O2O000000001N100000000000000000000000000000000001O0O100000001O0O10001N10001O0O101N10001N101N101N101N101N2O0O2N2O0O2N2N2O0O2N2N2N2N2N2N2N2N2N3L3N2N3L3N3L4M2M4L4L5J5L5J7H8H;BUcd<"}, 
+           ...]}
+```
+其中,list中的每个元素代表一个标注信息,标注信息中字段的含义如下所示: 
+
+| 字段名 | 含义 | 数据类型 | 备注 |
+|:--------|:------------|------|:-----|
+| x1 | 标注框左下角横坐标 | int | |
+| y1 | 标注框左下角纵坐标 | int | |
+| x2 | 标注框右上角横坐标 | int | |
+| y2 | 标注框右上角纵坐标 | int | |
+| name | 标注框中物体类标 | str | |
+| mask | 分割区域布尔型numpy编码后的字符串 | str | 该字段可以不存在,当不存在时只能进行目标检测 |
+
+`train_list.txt`和`val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为json文件相对于dataset的相对路径。如下所示:
+```
+easydata/0001.jpg easydata/0001.json
+easydata/0002.jpg easydata/0002.json
+...
+```
+
+`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
+```
+labelA
+labelB
+...
+```
+
+[点击这里](https://ai.baidu.com/easydata/),可以标注图像分类EasyDataDet数据集。  
+在PaddleX中,使用`paddlex.cv.datasets.EasyDataDet`([API说明](./apis/datasets.html#easydatadet))加载分类数据集。
+
+## 语义分割EasyDataSeg
+
+语义分割EasyDataSeg数据集包含存放图像和json文件的文件夹、标签文件及图像列表文件。
+参考数据文件结构如下:
+```
+./dataset/  # 数据集根目录ß
+|--easydata  # 存放图像和json文件的文件夹
+|  |--0001.jpg
+|  |--0001.json
+|  |--0002.jpg
+|  |--0002.json
+|  └--...
+|
+|--train_list.txt  # 训练文件列表文件
+|
+|--val_list.txt  # 验证文件列表文件
+|
+└--labels.txt  # 标签列表文件
+
+```
+其中,图像文件名应与json文件名一一对应。   
+
+每个json文件存储于`labels`相关的信息。如下所示:
+```
+"labels": [{"y1": 18, "x2": 883, "x1": 371, "y2": 404, "name": "labelA", 
+            "mask": "kVfc0`0Zg0<F7J7I5L5K4L4L4L3N3L3N3L3N2N3M2N2N2N2N2N2N1O2N2O1N2N1O2O1N101N1O2O1N101N10001N101N10001N10001O0O10001O000O100000001O0000000000000000000000O1000001O00000O101O000O101O0O101O0O2O0O101O0O2O0O2N2O0O2O0O2N2O1N1O2N2N2O1N2N2N2N2N2N2M3N3M2M4M2M4M3L4L4L4K6K5J7H9E\\iY1"}, 
+           {"y1": 314, "x2": 666, "x1": 227, "y2": 676, "name": "labelB",
+            "mask": "mdQ8g0Tg0:G8I6K5J5L4L4L4L4M2M4M2M4M2N2N2N3L3N2N2N2N2O1N1O2N2N2O1N1O2N2O0O2O1N1O2O0O2O0O2O001N100O2O000O2O000O2O00000O2O000000001N100000000000000000000000000000000001O0O100000001O0O10001N10001O0O101N10001N101N101N101N101N2O0O2N2O0O2N2N2O0O2N2N2N2N2N2N2N2N2N3L3N2N3L3N3L4M2M4L4L5J5L5J7H8H;BUcd<"}, 
+           ...]}
+```
+其中,list中的每个元素代表一个标注信息,标注信息中字段的含义如下所示: 
+
+| 字段名 | 含义 | 数据类型 | 备注 |
+|:--------|:------------|------|:-----|
+| x1 | 标注框左下角横坐标 | int | |
+| y1 | 标注框左下角纵坐标 | int | |
+| x2 | 标注框右上角横坐标 | int | |
+| y2 | 标注框右上角纵坐标 | int | |
+| name | 标注框中物体类标 | str | |
+| mask | 分割区域布尔型numpy编码后的字符串 | str | 该字段必须存在 |
+
+`train_list.txt`和`val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为json文件相对于dataset的相对路径。如下所示:
+```
+easydata/0001.jpg easydata/0001.json
+easydata/0002.jpg easydata/0002.json
+...
+```
+
+`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
+```
+labelA
+labelB
+...
+```
+
+[点击这里](https://ai.baidu.com/easydata/),可以标注图像分类EasyDataSeg数据集。  
+在PaddleX中,使用`paddlex.cv.datasets.EasyDataSeg`([API说明](./apis/datasets.html#easydataseg))加载分类数据集。

+ 3 - 0
paddlex/cv/datasets/__init__.py

@@ -16,3 +16,6 @@ from .imagenet import ImageNet
 from .voc import VOCDetection
 from .voc import VOCDetection
 from .coco import CocoDetection
 from .coco import CocoDetection
 from .seg_dataset import SegDataset
 from .seg_dataset import SegDataset
+from .easydata_cls import EasyDataCls
+from .easydata_det import EasyDataDet
+from .easydata_seg import EasyDataSeg

+ 1 - 1
paddlex/cv/datasets/coco.py

@@ -34,7 +34,7 @@ class CocoDetection(VOCDetection):
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
-            线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 

+ 86 - 0
paddlex/cv/datasets/easydata_cls.py

@@ -0,0 +1,86 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os.path as osp
+import random
+import copy
+import json
+import paddlex.utils.logging as logging
+from .imagenet import ImageNet
+from .dataset import is_pic
+from .dataset import get_encoding
+
+
+class EasyDataCls(ImageNet):
+    """读取EasyDataCls格式的分类数据集,并对样本进行相应的处理。
+
+    Args:
+        data_dir (str): 数据集所在的目录路径。
+        file_list (str): 描述数据集图片文件和类别id的文件路径(文本内每行路径为相对data_dir的相对路)。
+        label_list (str): 描述数据集包含的类别信息文件路径。
+        transforms (paddlex.cls.transforms): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核
+            数的一半。
+        buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
+        parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
+        shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+    """
+    
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 label_list,
+                 transforms=None,
+                 num_workers='auto',
+                 buffer_size=100,
+                 parallel_method='process',
+                 shuffle=False):
+        super(ImageNet, self).__init__(
+            transforms=transforms,
+            num_workers=num_workers,
+            buffer_size=buffer_size,
+            parallel_method=parallel_method,
+            shuffle=shuffle)
+        self.file_list = list()
+        self.labels = list()
+        self._epoch = 0
+        
+        with open(label_list, encoding=get_encoding(label_list)) as f:
+            for line in f:
+                item = line.strip()
+                self.labels.append(item)
+        logging.info("Starting to read file list from dataset...")
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                img_file, json_file = [osp.join(data_dir, x) \
+                        for x in line.strip().split()[:2]]
+                if not is_pic(img_file):
+                    continue
+                if not osp.isfile(json_file):
+                    continue
+                if not osp.exists(img_file):
+                    raise IOError(
+                        'The image file {} is not exist!'.format(img_file))
+                with open(json_file, mode='r', \
+                          encoding=get_encoding(json_file)) as j:
+                    json_info = json.load(j)
+                label = json_info['labels'][0]['name']
+                self.file_list.append([img_file, self.labels.index(label)])
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+    

+ 190 - 0
paddlex/cv/datasets/easydata_det.py

@@ -0,0 +1,190 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os.path as osp
+import random
+import copy
+import json
+import cv2
+import numpy as np
+import paddlex.utils.logging as logging
+from .voc import VOCDetection
+from .dataset import is_pic
+from .dataset import get_encoding
+
+class EasyDataDet(VOCDetection):
+    """读取EasyDataDet格式的检测数据集,并对样本进行相应的处理。
+
+    Args:
+        data_dir (str): 数据集所在的目录路径。
+        file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
+        label_list (str): 描述数据集包含的类别信息文件路径。
+        transforms (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子。
+        num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
+            系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
+            一半。
+        buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
+        parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
+        shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+    """
+    
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 label_list,
+                 transforms=None,
+                 num_workers='auto',
+                 buffer_size=100,
+                 parallel_method='process',
+                 shuffle=False):
+        super(VOCDetection, self).__init__(
+            transforms=transforms,
+            num_workers=num_workers,
+            buffer_size=buffer_size,
+            parallel_method=parallel_method,
+            shuffle=shuffle)
+        self.file_list = list()
+        self.labels = list()
+        self._epoch = 0
+        
+        annotations = {}
+        annotations['images'] = []
+        annotations['categories'] = []
+        annotations['annotations'] = []
+        
+        cname2cid = {}
+        label_id = 1
+        with open(label_list, encoding=get_encoding(label_list)) as fr:
+            for line in fr.readlines():
+                cname2cid[line.strip()] = label_id
+                label_id += 1
+                self.labels.append(line.strip())
+        logging.info("Starting to read file list from dataset...")
+        for k, v in cname2cid.items():
+            annotations['categories'].append({
+                'supercategory': 'component',
+                'id': v,
+                'name': k
+            })
+            
+        from pycocotools.mask import decode
+        ct = 0
+        ann_ct = 0
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                img_file, json_file = [osp.join(data_dir, x) \
+                        for x in line.strip().split()[:2]]
+                if not is_pic(img_file):
+                    continue
+                if not osp.isfile(json_file):
+                    continue
+                if not osp.exists(img_file):
+                    raise IOError(
+                        'The image file {} is not exist!'.format(img_file))
+                with open(json_file, mode='r', \
+                          encoding=get_encoding(json_file)) as j:
+                    json_info = json.load(j)
+                im_id = np.array([ct])
+                im = cv2.imread(img_file)
+                im_w = im.shape[1]
+                im_h = im.shape[0]
+                objs = json_info['labels']
+                gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
+                gt_class = np.zeros((len(objs), 1), dtype=np.int32)
+                gt_score = np.ones((len(objs), 1), dtype=np.float32)
+                is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
+                difficult = np.zeros((len(objs), 1), dtype=np.int32)
+                gt_poly = [None] * len(objs)
+                for i, obj in enumerate(objs):
+                    cname = obj['name']
+                    gt_class[i][0] = cname2cid[cname]
+                    x1 = max(0, obj['x1'])
+                    y1 = max(0, obj['y1'])
+                    x2 = min(im_w - 1, obj['x2'])
+                    y2 = min(im_h - 1, obj['y2'])
+                    gt_bbox[i] = [x1, y1, x2, y2]
+                    is_crowd[i][0] = 0
+                    if 'mask' in obj:
+                        mask_dict = {}
+                        mask_dict['size'] = [im_h, im_w]
+                        mask_dict['counts'] = obj['mask'].encode()
+                        mask = decode(mask_dict)
+                        gt_poly[i] = self.mask2polygon(mask)
+                    annotations['annotations'].append({
+                        'iscrowd':
+                        0,
+                        'image_id':
+                        int(im_id[0]),
+                        'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
+                        'area':
+                        float((x2 - x1 + 1) * (y2 - y1 + 1)),
+                        'segmentation':
+                        [[x1, y1, x1, y2, x2, y2, x2, y1]] if gt_poly[i] is None else gt_poly[i],
+                        'category_id':
+                        cname2cid[cname],
+                        'id':
+                        ann_ct,
+                        'difficult':
+                        0
+                    })
+                    ann_ct += 1
+                im_info = {
+                    'im_id': im_id,
+                    'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                }
+                label_info = {
+                    'is_crowd': is_crowd,
+                    'gt_class': gt_class,
+                    'gt_bbox': gt_bbox,
+                    'gt_score': gt_score,
+                    'difficult': difficult
+                }
+                if None not in gt_poly:
+                    label_info['gt_poly'] = gt_poly
+                voc_rec = (im_info, label_info)
+                if len(objs) != 0:
+                    self.file_list.append([img_file, voc_rec])
+                    ct += 1
+                    annotations['images'].append({
+                        'height':
+                        im_h,
+                        'width':
+                        im_w,
+                        'id':
+                        int(im_id[0]),
+                        'file_name':
+                        osp.split(img_file)[1]
+                    })
+
+        if not len(self.file_list) > 0:
+            raise Exception('not found any voc record in %s' % (file_list))
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+        self.num_samples = len(self.file_list)
+        from pycocotools.coco import COCO
+        self.coco_gt = COCO()
+        self.coco_gt.dataset = annotations
+        self.coco_gt.createIndex()
+        
+    def mask2polygon(self, mask):
+        contours, hierarchy = cv2.findContours(
+            (mask).astype(np.uint8), cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
+        segmentation = []
+        for contour in contours:
+            contour_list = contour.flatten().tolist()
+            if len(contour_list) > 4:
+                segmentation.append(contour_list)
+        return segmentation

+ 116 - 0
paddlex/cv/datasets/easydata_seg.py

@@ -0,0 +1,116 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os.path as osp
+import random
+import copy
+import json
+import cv2
+import numpy as np
+import paddlex.utils.logging as logging
+from .dataset import Dataset
+from .dataset import get_encoding
+from .dataset import is_pic
+
+class EasyDataSeg(Dataset):
+    """读取EasyDataSeg语义分割任务数据集,并对样本进行相应的处理。
+
+    Args:
+        data_dir (str): 数据集所在的目录路径。
+        file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
+        label_list (str): 描述数据集包含的类别信息文件路径。
+        transforms (list): 数据集中每个样本的预处理/增强算子。
+        num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。
+        buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
+        parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
+        shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+    """
+
+    def __init__(self,
+                 data_dir,
+                 file_list,
+                 label_list,
+                 transforms=None,
+                 num_workers='auto',
+                 buffer_size=100,
+                 parallel_method='process',
+                 shuffle=False):
+        super(EasyDataSeg, self).__init__(
+            transforms=transforms,
+            num_workers=num_workers,
+            buffer_size=buffer_size,
+            parallel_method=parallel_method,
+            shuffle=shuffle)
+        self.file_list = list()
+        self.labels = list()
+        self._epoch = 0
+
+        from pycocotools.mask import decode
+        cname2cid = {}
+        label_id = 0
+        with open(label_list, encoding=get_encoding(label_list)) as fr:
+            for line in fr.readlines():
+                cname2cid[line.strip()] = label_id
+                label_id += 1
+                self.labels.append(line.strip())
+                
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                img_file, json_file = [osp.join(data_dir, x) \
+                        for x in line.strip().split()[:2]]
+                if not is_pic(img_file):
+                    continue
+                if not osp.isfile(json_file):
+                    continue
+                if not osp.exists(img_file):
+                    raise IOError(
+                        'The image file {} is not exist!'.format(img_file))
+                with open(json_file, mode='r', \
+                          encoding=get_encoding(json_file)) as j:
+                    json_info = json.load(j)
+                im = cv2.imread(img_file)
+                im_w = im.shape[1]
+                im_h = im.shape[0]
+                objs = json_info['labels']
+                lable_npy = np.zeros([im_h, im_w]).astype('uint8')
+                for i, obj in enumerate(objs):
+                    cname = obj['name']
+                    cid = cname2cid[cname]
+                    mask_dict = {}
+                    mask_dict['size'] = [im_h, im_w]
+                    mask_dict['counts'] = obj['mask'].encode()
+                    mask = decode(mask_dict)
+                    mask *= cid
+                    conflict_index = np.where(((lable_npy > 0) & (mask == cid)) == True)
+                    mask[conflict_index] = 0
+                    lable_npy += mask
+                self.file_list.append([img_file, lable_npy])
+        self.num_samples = len(self.file_list)
+        logging.info("{} samples in file {}".format(
+            len(self.file_list), file_list))
+
+    def iterator(self):
+        self._epoch += 1
+        self._pos = 0
+        files = copy.deepcopy(self.file_list)
+        if self.shuffle:
+            random.shuffle(files)
+        files = files[:self.num_samples]
+        self.num_samples = len(files)
+        for f in files:
+            lable_npy = f[1]
+            sample = [f[0], None, lable_npy]
+            yield sample

+ 1 - 1
paddlex/cv/datasets/imagenet.py

@@ -35,7 +35,7 @@ class ImageNet(Dataset):
             数的一半。
             数的一半。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
-            线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 

+ 1 - 1
paddlex/cv/datasets/seg_dataset.py

@@ -33,7 +33,7 @@ class SegDataset(Dataset):
         num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。
         num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
-            线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 

+ 1 - 1
paddlex/cv/datasets/voc.py

@@ -38,7 +38,7 @@ class VOCDetection(Dataset):
             一半。
             一半。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         buffer_size (int): 数据集中样本在预处理过程中队列的缓存长度,以样本数为单位。默认为100。
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
         parallel_method (str): 数据集中样本在预处理过程中并行处理的方式,支持'thread'
-            线程和'process'进程两种方式。默认为'thread'(Windows和Mac下会强制使用thread,该参数无效)。
+            线程和'process'进程两种方式。默认为'process'(Windows和Mac下会强制使用thread,该参数无效)。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
     """
     """
 
 

+ 4 - 3
paddlex/cv/models/faster_rcnn.py

@@ -178,7 +178,7 @@ class FasterRCNN(BaseAPI):
             log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为20。
             log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为20。
             save_dir (str): 模型保存路径。默认值为'output'。
             save_dir (str): 模型保存路径。默认值为'output'。
             pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
             pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
-                则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为None
+                则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'
             optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
             optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
                 fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
                 fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
             learning_rate (float): 默认优化器的初始学习率。默认为0.0025。
             learning_rate (float): 默认优化器的初始学习率。默认为0.0025。
@@ -199,11 +199,12 @@ class FasterRCNN(BaseAPI):
         if metric is None:
         if metric is None:
             if isinstance(train_dataset, paddlex.datasets.CocoDetection):
             if isinstance(train_dataset, paddlex.datasets.CocoDetection):
                 metric = 'COCO'
                 metric = 'COCO'
-            elif isinstance(train_dataset, paddlex.datasets.VOCDetection):
+            elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
+                    isinstance(train_dataset, paddlex.datasets.EasyDataDet):
                 metric = 'VOC'
                 metric = 'VOC'
             else:
             else:
                 raise ValueError(
                 raise ValueError(
-                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection."
+                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
                 )
                 )
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         self.metric = metric

+ 3 - 2
paddlex/cv/models/mask_rcnn.py

@@ -162,11 +162,12 @@ class MaskRCNN(FasterRCNN):
             ValueError: 模型从inference model进行加载。
             ValueError: 模型从inference model进行加载。
         """
         """
         if metric is None:
         if metric is None:
-            if isinstance(train_dataset, paddlex.datasets.CocoDetection):
+            if isinstance(train_dataset, paddlex.datasets.CocoDetection) or \
+                    isinstance(train_dataset, paddlex.datasets.EasyDataDet):
                 metric = 'COCO'
                 metric = 'COCO'
             else:
             else:
                 raise Exception(
                 raise Exception(
-                    "train_dataset should be datasets.COCODetection.")
+                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet.")
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         self.metric = metric
         if not self.trainable:
         if not self.trainable:

+ 4 - 3
paddlex/cv/models/yolo_v3.py

@@ -177,7 +177,7 @@ class YOLOv3(BaseAPI):
             log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为10。
             log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为10。
             save_dir (str): 模型保存路径。默认值为'output'。
             save_dir (str): 模型保存路径。默认值为'output'。
             pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
             pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
-                则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为None
+                则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'
             optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
             optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
                 fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
                 fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
             learning_rate (float): 默认优化器的学习率。默认为1.0/8000。
             learning_rate (float): 默认优化器的学习率。默认为1.0/8000。
@@ -203,11 +203,12 @@ class YOLOv3(BaseAPI):
         if metric is None:
         if metric is None:
             if isinstance(train_dataset, paddlex.datasets.CocoDetection):
             if isinstance(train_dataset, paddlex.datasets.CocoDetection):
                 metric = 'COCO'
                 metric = 'COCO'
-            elif isinstance(train_dataset, paddlex.datasets.VOCDetection):
+            elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
+                    isinstance(train_dataset, paddlex.datasets.EasyDataDet):
                 metric = 'VOC'
                 metric = 'VOC'
             else:
             else:
                 raise ValueError(
                 raise ValueError(
-                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection."
+                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
                 )
                 )
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         self.metric = metric

+ 2 - 2
paddlex/cv/transforms/seg_transforms.py

@@ -66,8 +66,8 @@ class Compose:
         if self.to_rgb:
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if label is not None:
         if label is not None:
-            label = np.asarray(Image.open(label))
-
+            if not isinstance(label, np.ndarray):
+                label = np.asarray(Image.open(label))
         for op in self.transforms:
         for op in self.transforms:
             outputs = op(im, im_info, label)
             outputs = op(im, im_info, label)
             im = outputs[0]
             im = outputs[0]