Browse Source

Merge pull request #4 from PaddlePaddle/develop

Update
mamingjie-China 5 years ago
parent
commit
f1bc0bac9b

+ 124 - 0
docs/apis/models/detection.md

@@ -1,5 +1,129 @@
 # Object Detection
 
+## paddlex.det.PPYOLO
+
+```python
+paddlex.det.PPYOLO(num_classes=80, backbone='ResNet50_vd_ssld', with_dcn_v2=True, anchors=None, anchor_masks=None, use_coord_conv=True, use_iou_aware=True, use_spp=True, use_drop_block=True, scale_x_y=1.05, ignore_threshold=0.7, label_smooth=False, use_iou_loss=True, use_matrix_nms=True, nms_score_threshold=0.01, nms_topk=1000, nms_keep_topk=100, nms_iou_threshold=0.45, train_random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
+```
+
+> 构建PPYOLO检测器。**注意在PPYOLO,num_classes不需要包含背景类,如目标包括human、dog两种,则num_classes设为2即可,这里与FasterRCNN/MaskRCNN有差别**
+
+> **参数**
+>
+> > - **num_classes** (int): 类别数。默认为80。
+> > - **backbone** (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
+> > - **with_dcn_v2** (bool): Backbone是否使用DCNv2结构。默认为True。
+> > - **anchors** (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
+> >                  [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+>                   [59, 119], [116, 90], [156, 198], [373, 326]]。
+> > - **anchor_masks** (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
+> >                    [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
+> > - **use_coord_conv** (bool): 是否使用CoordConv。默认值为True。
+> > - **use_iou_aware** (bool): 是否使用IoU Aware分支。默认值为True。
+> > - **use_spp** (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
+> > - **use_drop_block** (bool): 是否使用Drop Block。默认值为True。
+> > - **scale_x_y** (float): 调整中心点位置时的系数因子。默认值为1.05。
+> > - **use_iou_loss** (bool): 是否使用IoU loss。默认值为True。
+> > - **use_matrix_nms** (bool): 是否使用Matrix NMS。默认值为True。  
+> > - **ignore_threshold** (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
+> > - **nms_score_threshold** (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
+> > - **nms_topk** (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
+> > - **nms_keep_topk** (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
+> > - **nms_iou_threshold** (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
+> > - **label_smooth** (bool): 是否使用label smooth。默认值为False。
+> > - **train_random_shapes** (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
+
+### train
+
+```python
+train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None, use_ema=True, ema_decay=0.9998)
+```
+
+> PPYOLO模型的训练接口,函数内置了`piecewise`学习率衰减策略和`momentum`优化器。
+
+> **参数**
+>
+> > - **num_epochs** (int): 训练迭代轮数。
+> > - **train_dataset** (paddlex.datasets): 训练数据读取器。
+> > - **train_batch_size** (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡数量之商为验证数据batch大小。默认值为8。
+> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
+> > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为20。
+> > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
+> > - **save_dir** (str): 模型保存路径。默认值为'output'。
+> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为None。
+> > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
+> > - **learning_rate** (float): 默认优化器的学习率。默认为1.0/8000。
+> > - **warmup_steps** (int):  默认优化器进行warmup过程的步数。默认为1000。
+> > - **warmup_start_lr** (int): 默认优化器warmup的起始学习率。默认为0.0。
+> > - **lr_decay_epochs** (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
+> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
+> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
+> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
+> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
+> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
+> > - **use_ema** (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
+> > - **ema_decay** (float): 指数衰减率。默认值为0.9998。
+
+### evaluate
+
+```python
+evaluate(self, eval_dataset, batch_size=1, epoch_id=None, metric=None, return_details=False)
+```
+
+> PPYOLO模型的评估接口,模型评估后会返回在验证集上的指标`box_map`(metric指定为'VOC'时)或`box_mmap`(metric指定为`COCO`时)。
+
+> **参数**
+>
+> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
+> > - **batch_size** (int): 验证数据批大小。默认为1。
+> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
+> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC';如为COCODetection,则`metric`为'COCO'默认为None, 如为EasyData类型数据集,同时也会使用'VOC'。
+> > - **return_details** (bool): 是否返回详细信息。默认值为False。
+> >
+>  **返回值**
+>
+> > - **tuple** (metrics, eval_details) | **dict** (metrics): 当`return_details`为True时,返回(metrics, eval_details),当`return_details`为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,分别表示平均准确率平均值在各个阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
+
+### predict
+
+```python
+predict(self, img_file, transforms=None)
+```
+
+> PPYOLO模型预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`predict`接口时,用户需要再重新定义`test_transforms`传入给`predict`接口
+
+> **参数**
+>
+> > - **img_file** (str|np.ndarray): 预测图像路径或numpy数组(HWC排列,BGR格式)。
+> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
+>
+> **返回值**
+>
+> > - **list**: 预测结果列表,列表中每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
+
+
+### batch_predict
+
+```python
+batch_predict(self, img_file_list, transforms=None, thread_num=2)
+```
+
+> PPYOLO模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义`test_transforms`传入给`batch_predict`接口
+
+> **参数**
+>
+> > - **img_file_list** (str|np.ndarray): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
+> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
+> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
+>
+> **返回值**
+>
+> > - **list**: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个元素均为一个dict,key包括'bbox', 'category', 'category_id', 'score',分别表示每个预测目标的框坐标信息、类别、类别id、置信度,其中框坐标信息为[xmin, ymin, w, h],即左上角x, y坐标和框的宽和高。
+
+
 ## paddlex.det.YOLOv3
 
 ```python

+ 1 - 0
docs/appendix/model_zoo.md

@@ -45,6 +45,7 @@
 |[FasterRCNN-ResNet101-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_1x.tar)| 244.2MB | 119.788 | 38.7 |
 |[FasterRCNN-ResNet101_vd-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar) |244.3MB | 156.097 | 40.5 |
 |[FasterRCNN-HRNet_W18-FPN](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_1x.tar) |115.5MB | 81.592 | 36 |
+|[PPYOLO](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | 329.1MB | - |45.9 |
 |[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar)|249.2MB | 42.672 | 38.9 |
 |[YOLOv3-MobileNetV1](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar) |99.2MB | 15.442 | 29.3 |
 |[YOLOv3-MobileNetV3_large](https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams)|100.7MB | 143.322 | 31.6 |

+ 1 - 0
docs/examples/solutions.md

@@ -42,6 +42,7 @@ PaddleX针对图像分类、目标检测、实例分割和语义分割4种视觉
 | YOLOv3-MobileNetV3_larget | 适用于追求高速预测的移动端场景 | 100.7MB | 143.322 | - | - | 31.6 |
 | YOLOv3-MobileNetV1 | 精度相对偏低,适用于追求高速预测的服务器端场景 | 99.2MB| 15.422 | - | - | 29.3 |
 | YOLOv3-DarkNet53 | 在预测速度和模型精度上都有较好的表现,适用于大多数的服务器端场景| 249.2MB | 42.672 | - | - | 38.9 |
+| PPYOLO | 预测速度和模型精度都比YOLOv3-DarkNet53优异,适用于大多数的服务器端场景 | 329.1MB | - | - | - | 45.9 |
 | FasterRCNN-ResNet50-FPN | 经典的二阶段检测器,预测速度相对较慢,适用于重视模型精度的服务器端场景 | 167.MB | 83.189 | - | -| 37.2 |
 | FasterRCNN-HRNet_W18-FPN | 适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景 | 115.5MB | 81.592 | - | - | 36 |
 | FasterRCNN-ResNet101_vd-FPN | 超高精度模型,预测时间更长,在处理较大数据量时有较高的精度,适用于服务器端场景 | 244.3MB | 156.097 | - | - | 40.5 |

+ 1 - 0
docs/train/object_detection.md

@@ -13,6 +13,7 @@ PaddleX目前提供了FasterRCNN和YOLOv3两种检测结构,多种backbone模
 | [YOLOv3-MobileNetV1](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv1.py) |  29.3%  |  99.2MB  |  15.442ms   | -  |  模型小,预测速度快,适用于低性能或移动端设备   |
 | [YOLOv3-MobileNetV3](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_mobilenetv3.py)        | 31.6%  | 100.7MB   |  143.322ms  | -  |  模型小,移动端上预测速度有优势   |
 | [YOLOv3-DarkNet53](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/yolov3_darknet53.py)     | 38.9%  | 249.2MB   | 42.672ms   | -  |  模型较大,预测速度快,适用于服务端   |
+| [PPYOLO](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/ppyolo.py) | 45.9% | 329.1MB | - | - | 模型较大,预测速度比YOLOv3-DarkNet53更快,适用于服务端 |
 | [FasterRCNN-ResNet50-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r50_fpn.py)   |  37.2%   |   167.7MB    |  197.715ms       |   -    | 模型精度高,适用于服务端部署   |
 | [FasterRCNN-ResNet18-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_r18_fpn.py)   |  32.6%   |   173.2MB    |  -       |   -    | 模型精度高,适用于服务端部署   |
 | [FasterRCNN-HRNet-FPN](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/object_detection/faster_rcnn_hrnet_fpn.py)   |  36.0%   |   115.MB    |  81.592ms       |   -    | 模型精度高,预测速度快,适用于服务端部署   |

+ 2 - 2
paddlex/cv/models/base.py

@@ -548,7 +548,7 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
-                if hasattr(self, 'use_ema'):
+                if getattr(self, 'use_ema', False):
                     self.exe.run(self.ema.apply_program)
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
@@ -576,7 +576,7 @@ class BaseAPI:
                             log_writer.add_scalar(
                                 "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
-                if hasattr(self, 'use_ema'):
+                if getattr(self, 'use_ema', False):
                     self.exe.run(self.ema.restore_program)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()

+ 13 - 3
paddlex/cv/models/ppyolo.py

@@ -37,11 +37,19 @@ class PPYOLO(BaseAPI):
     Args:
         num_classes (int): 类别数。默认为80。
         backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。
+        with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
         anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
                     [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                     [59, 119], [116, 90], [156, 198], [373, 326]]。
         anchor_masks (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
                     [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
+        use_coord_conv (bool): 是否使用CoordConv。默认值为True。
+        use_iou_aware (bool): 是否使用IoU Aware分支。默认值为True。
+        use_spp (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
+        use_drop_block (bool): 是否使用Drop Block。默认值为True。
+        scale_x_y (float): 调整中心点位置时的系数因子。默认值为1.05。
+        use_iou_loss (bool): 是否使用IoU loss。默认值为True。
+        use_matrix_nms (bool): 是否使用Matrix NMS。默认值为True。
         ignore_threshold (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
         nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
         nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
@@ -54,7 +62,7 @@ class PPYOLO(BaseAPI):
     def __init__(
             self,
             num_classes=80,
-            backbone='ResNet50_vd',
+            backbone='ResNet50_vd_ssld',
             with_dcn_v2=True,
             # YOLO Head
             anchors=None,
@@ -79,7 +87,7 @@ class PPYOLO(BaseAPI):
             ]):
         self.init_params = locals()
         super(PPYOLO, self).__init__('detector')
-        backbones = ['ResNet50_vd']
+        backbones = ['ResNet50_vd_ssld']
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
         self.backbone = backbone
@@ -116,7 +124,7 @@ class PPYOLO(BaseAPI):
         self.with_dcn_v2 = with_dcn_v2
 
     def _get_backbone(self, backbone_name):
-        if backbone_name == 'ResNet50_vd':
+        if backbone_name.startswith('ResNet50_vd'):
             backbone = paddlex.cv.nets.ResNet(
                 norm_type='sync_bn',
                 layers=50,
@@ -252,6 +260,8 @@ class PPYOLO(BaseAPI):
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
             resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
+            use_ema (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
+            ema_decay (float): 指数衰减率。默认值为0.9998。
 
         Raises:
             ValueError: 评估类型不在指定列表中。

+ 6 - 2
paddlex/cv/models/utils/pretrain_weights.py

@@ -116,7 +116,9 @@ coco_pretrain = {
     'DeepLabv3p_MobileNetV2_x1.0_COCO':
     'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
     'DeepLabv3p_Xception65_COCO':
-    'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'
+    'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz',
+    'PPYOLO_ResNet50_vd_ssld_COCO':
+    'https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams'
 }
 
 cityscapes_pretrain = {
@@ -226,7 +228,9 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
-        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
+        if class_name in [
+                'YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p', 'PPYOLO'
+        ]:
             backbone = '{}_{}'.format(class_name, backbone)
         backbone = "{}_{}".format(backbone, flag)
         if flag == 'COCO':

+ 1 - 1
paddlex/cv/models/yolo_v3.py

@@ -60,7 +60,7 @@ class YOLOv3(PPYOLO):
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
-        super(YOLOv3, self).__init__('detector')
+        super(PPYOLO, self).__init__('detector')
         self.backbone = backbone
         self.num_classes = num_classes
         self.anchors = anchors

+ 3 - 3
paddlex/tools/x2coco.py

@@ -147,7 +147,7 @@ class LabelMe2COCO(X2COCO):
             img_name_part = osp.splitext(img_file)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(osp.join(image_dir, img_file))
+                os.remove(osp.join(img_dir, img_file))
                 continue
             image_id = image_id + 1
             with open(json_file, mode='r', \
@@ -220,7 +220,7 @@ class EasyData2COCO(X2COCO):
             img_name_part = osp.splitext(img_file)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(osp.join(image_dir, img_file))
+                os.remove(osp.join(img_dir, img_file))
                 continue
             image_id = image_id + 1
             with open(json_file, mode='r', \
@@ -317,7 +317,7 @@ class JingLing2COCO(X2COCO):
             img_name_part = osp.splitext(img_file)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(osp.join(image_dir, img_file))
+                os.remove(osp.join(img_dir, img_file))
                 continue
             image_id = image_id + 1
             with open(json_file, mode='r', \

+ 4 - 3
paddlex/tools/x2seg.py

@@ -23,6 +23,7 @@ import shutil
 import numpy as np
 import PIL.Image
 from .base import MyEncoder, is_pic, get_encoding
+import math
 
 class X2Seg(object):
     def __init__(self):
@@ -140,7 +141,7 @@ class JingLing2Seg(X2Seg):
             img_name_part = osp.splitext(img_name)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(os.remove(osp.join(image_dir, img_name)))
+                os.remove(osp.join(image_dir, img_name))
                 continue
             with open(json_file, mode="r", \
                               encoding=get_encoding(json_file)) as j:
@@ -226,7 +227,7 @@ class LabelMe2Seg(X2Seg):
             img_name_part = osp.splitext(img_name)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(os.remove(osp.join(image_dir, img_name)))
+                os.remove(osp.join(image_dir, img_name))
                 continue
             img_file = osp.join(image_dir, img_name)
             img = np.asarray(PIL.Image.open(img_file))
@@ -260,7 +261,7 @@ class EasyData2Seg(X2Seg):
             img_name_part = osp.splitext(img_name)[0]
             json_file = osp.join(json_dir, img_name_part + ".json")
             if not osp.exists(json_file):
-                os.remove(os.remove(osp.join(image_dir, img_name)))
+                os.remove(osp.join(image_dir, img_name))
                 continue
             with open(json_file, mode="r", \
                               encoding=get_encoding(json_file)) as j:

+ 1 - 0
tutorials/train/README.md

@@ -12,6 +12,7 @@
 |object_detection/faster_rcnn_hrnet_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
 |object_detection/faster_rcnn_r18_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
 |object_detection/faster_rcnn_r50_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
+|object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 |
 |object_detection/yolov3_darknet53.py | 目标检测YOLOv3 | 昆虫检测 |
 |object_detection/yolov3_mobilenetv1.py | 目标检测YOLOv3 | 昆虫检测 |
 |object_detection/yolov3_mobilenetv3.py | 目标检测YOLOv3 | 昆虫检测 |