
Merge pull request #266 from FlyingQianMM/develop_qh

add ppyolo
Jason 5 years ago
parent
commit
f80854690b
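
A minimal usage sketch of the API this PR adds (hedged: the dataset paths are placeholders, PPYOLO is assumed to be exposed as paddlex.det.PPYOLO in the same way YOLOv3 is, and the train() arguments follow the defaults in paddlex/cv/models/ppyolo.py below):

# Hedged sketch: train the new PPYOLO detector through the paddlex API.
# The 'dataset/...' paths are illustrative, not files shipped with this PR.
import paddlex as pdx
from paddlex.det import transforms

train_transforms = transforms.Compose([
    transforms.RandomDistort(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=608),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

train_dataset = pdx.datasets.VOCDetection(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms)

model = pdx.det.PPYOLO(num_classes=len(train_dataset.labels))
model.train(
    num_epochs=270,
    train_dataset=train_dataset,
    train_batch_size=8,
    learning_rate=1.0 / 8000,
    warmup_steps=1000,
    lr_decay_epochs=[213, 240],
    use_ema=True,
    save_dir='output/ppyolo')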

+ 1 - 0
paddlex/cv/__init__.py

@@ -26,6 +26,7 @@ ResNet50 = models.ResNet50
 DarkNet53 = models.DarkNet53
 # detection
 YOLOv3 = models.YOLOv3
+PPYOLO = models.PPYOLO
 #EAST = models.EAST
 FasterRCNN = models.FasterRCNN
 MaskRCNN = models.MaskRCNN

+ 10 - 8
paddlex/cv/datasets/dataset.py

@@ -115,7 +115,7 @@ def multithread_reader(mapper,
         while not isinstance(sample, EndSignal):
             batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
             sample = out_queue.get()
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
             else:
                 batch_data.append(sample)
                 if len(batch_data) == batch_size:
-                    batch_data = generate_minibatch(batch_data)
+                    batch_data = generate_minibatch(batch_data, mapper=mapper)
                     yield batch_data
                     batch_data = []
         if not drop_last and len(batch_data) != 0:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
 
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
             else:
                 batch_data.append(sample)
                 if len(batch_data) == batch_size:
-                    batch_data = generate_minibatch(batch_data)
+                    batch_data = generate_minibatch(batch_data, mapper=mapper)
                     yield batch_data
                     batch_data = []
         if len(batch_data) != 0 and not drop_last:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
 
     return queue_reader
 
 
-def generate_minibatch(batch_data, label_padding_value=255):
+def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
+    if mapper is not None and mapper.batch_transforms is not None:
+        for op in mapper.batch_transforms:
+            batch_data = op(batch_data)
     # if batch_size is 1, do not pad the image
     if len(batch_data) == 1:
         return batch_data
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 2:
-           # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
+            # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
             if len(data[1]) == 0 or 'padding' not in [
                     data[1][i][0] for i in range(len(data[1]))
             ]:
                 data[1].append(('padding', [im_h, im_w]))
             padding_batch.append((padding_im, data[1], data[2]))
 
-            
         elif len(data) > 1:
             if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation during the training

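The net effect of threading `mapper` through generate_minibatch above: any batch-level transforms configured on the reader's transform pipeline now run on each complete minibatch before padding. A self-contained sketch of that contract (both classes are hypothetical stand-ins, not paddlex types):

# Hedged stand-ins for a transform pipeline that carries batch_transforms.
class FakeBatchOp(object):             # hypothetical batch-level transform
    def __call__(self, batch):
        return [(im.upper(),) for (im,) in batch]

class FakeMapper(object):              # stands in for a transforms.Compose
    batch_transforms = [FakeBatchOp()]

def generate_minibatch(batch_data, mapper=None):
    # mirrors the new logic above: batch ops run before any padding
    if mapper is not None and mapper.batch_transforms is not None:
        for op in mapper.batch_transforms:
            batch_data = op(batch_data)
    return batch_data

print(generate_minibatch([('a',), ('b',)], mapper=FakeMapper()))
# -> [('A',), ('B',)]
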
+ 1 - 0
paddlex/cv/models/__init__.py

@@ -38,6 +38,7 @@ from .classifier import HRNet_W18
 from .classifier import AlexNet
 from .base import BaseAPI
 from .yolo_v3 import YOLOv3
+from .ppyolo import PPYOLO
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
 from .unet import UNet

+ 17 - 8
paddlex/cv/models/base.py

@@ -246,8 +246,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
-                                                      pretrain_weights, fuse_bn)
+            paddlex.utils.utils.load_pretrain_weights(
+                self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # Prune the model
         if sensitivities_file is not None:
             import paddleslim
@@ -351,7 +351,9 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))
 
     def export_inference_model(self, save_dir):
-        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_input_names = [
+            var.name for var in list(self.test_inputs.values())
+        ]
         test_outputs = list(self.test_outputs.values())
         with fluid.scope_guard(self.scope):
             if self.__class__.__name__ == 'MaskRCNN':
@@ -389,7 +391,8 @@ class BaseAPI:
 
         # Marker indicating the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(
+            save_dir))
 
     def train_loop(self,
                    num_epochs,
@@ -516,11 +519,13 @@ class BaseAPI:
                         eta = ((num_epochs - i) * total_num_steps - step - 1
                                ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * time_eval_one_epoch
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * time_eval_one_epoch
                     else:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
 
                     logging.info(
@@ -543,6 +548,8 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.apply_program)
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
@@ -569,6 +576,8 @@ class BaseAPI:
                             log_writer.add_scalar(
                                 "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.restore_program)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
                 if best_model_epoch > 0:

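The two new hunks in train_loop bracket evaluation and checkpointing with EMA weights: apply the averaged parameters, run evaluate() and save_model(), then restore the raw weights before training resumes. A hedged, self-contained sketch of that fluid 1.x pattern on a toy network (paddlex wiring omitted):

# Hedged sketch of the EMA apply/restore bracket added above (Paddle 1.x).
import numpy as np
import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 4], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=2))
    fluid.optimizer.SGD(learning_rate=0.1).minimize(loss)
    ema = fluid.optimizer.ExponentialMovingAverage(0.9998)
    ema.update()  # ppyolo.py calls this right after optimizer.minimize

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
exe.run(main_prog, feed={'x': np.random.rand(8, 4).astype('float32')})

exe.run(ema.apply_program)    # evaluate and save with averaged weights
# ... evaluate() and save_model() happen here in train_loop ...
exe.run(ema.restore_program)  # put the raw training weights back
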
+ 555 - 0
paddlex/cv/models/ppyolo.py

@@ -0,0 +1,555 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import math
+import tqdm
+import os.path as osp
+import numpy as np
+from multiprocessing.pool import ThreadPool
+import paddle.fluid as fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.optimizer import ExponentialMovingAverage
+import paddlex.utils.logging as logging
+import paddlex
+import copy
+from paddlex.cv.transforms import arrange_transforms
+from paddlex.cv.datasets import generate_minibatch
+from .base import BaseAPI
+from collections import OrderedDict
+from .utils.detection_eval import eval_results, bbox2out
+
+
+class PPYOLO(BaseAPI):
+    """构建PPYOLO,并实现其训练、评估、预测和模型导出。
+
+    Args:
+        num_classes (int): 类别数。默认为80。
+        backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。
+        anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
+                    [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                    [59, 119], [116, 90], [156, 198], [373, 326]]。
+        anchor_masks (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
+                    [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
+        ignore_threshold (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
+        nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
+        nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
+        nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
+        nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
+        label_smooth (bool): 是否使用label smooth。默认值为False。
+        train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
+    """
+
+    def __init__(
+            self,
+            num_classes=80,
+            backbone='ResNet50_vd',
+            with_dcn_v2=True,
+            # YOLO Head
+            anchors=None,
+            anchor_masks=None,
+            use_coord_conv=True,
+            use_iou_aware=True,
+            use_spp=True,
+            use_drop_block=True,
+            scale_x_y=1.05,
+            # PPYOLO Loss
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_iou_loss=True,
+            # NMS
+            use_matrix_nms=True,
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            train_random_shapes=[
+                320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+            ]):
+        self.init_params = locals()
+        super(PPYOLO, self).__init__('detector')
+        backbones = ['ResNet50_vd']
+        assert backbone in backbones, "backbone should be one of {}".format(
+            backbones)
+        self.backbone = backbone
+        self.num_classes = num_classes
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        if anchors is None:
+            self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                            [59, 119], [116, 90], [156, 198], [373, 326]]
+        if anchor_masks is None:
+            self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+        self.ignore_threshold = ignore_threshold
+        self.nms_score_threshold = nms_score_threshold
+        self.nms_topk = nms_topk
+        self.nms_keep_topk = nms_keep_topk
+        self.nms_iou_threshold = nms_iou_threshold
+        self.label_smooth = label_smooth
+        self.sync_bn = True
+        self.train_random_shapes = train_random_shapes
+        self.fixed_input_shape = None
+        self.use_fine_grained_loss = False
+        if use_coord_conv or use_iou_aware or use_spp or use_drop_block or use_iou_loss:
+            self.use_fine_grained_loss = True
+        self.use_coord_conv = use_coord_conv
+        self.use_iou_aware = use_iou_aware
+        self.use_spp = use_spp
+        self.use_drop_block = use_drop_block
+        self.use_iou_loss = use_iou_loss
+        self.scale_x_y = scale_x_y
+        self.max_height = 608
+        self.max_width = 608
+        self.use_matrix_nms = use_matrix_nms
+        self.use_ema = False
+        self.with_dcn_v2 = with_dcn_v2
+
+    def _get_backbone(self, backbone_name):
+        if backbone_name == 'ResNet50_vd':
+            backbone = paddlex.cv.nets.ResNet(
+                norm_type='sync_bn',
+                layers=50,
+                freeze_norm=False,
+                norm_decay=0.,
+                feature_maps=[3, 4, 5],
+                freeze_at=0,
+                variant='d',
+                dcn_v2_stages=[5] if self.with_dcn_v2 else [])
+        return backbone
+
+    def build_net(self, mode='train'):
+        model = paddlex.cv.nets.detection.YOLOv3(
+            backbone=self._get_backbone(self.backbone),
+            num_classes=self.num_classes,
+            mode=mode,
+            anchors=self.anchors,
+            anchor_masks=self.anchor_masks,
+            ignore_threshold=self.ignore_threshold,
+            label_smooth=self.label_smooth,
+            nms_score_threshold=self.nms_score_threshold,
+            nms_topk=self.nms_topk,
+            nms_keep_topk=self.nms_keep_topk,
+            nms_iou_threshold=self.nms_iou_threshold,
+            fixed_input_shape=self.fixed_input_shape,
+            coord_conv=self.use_coord_conv,
+            iou_aware=self.use_iou_aware,
+            scale_x_y=self.scale_x_y,
+            spp=self.use_spp,
+            drop_block=self.use_drop_block,
+            use_matrix_nms=self.use_matrix_nms,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            use_iou_loss=self.use_iou_loss,
+            batch_size=self.batch_size_per_gpu
+            if hasattr(self, 'batch_size_per_gpu') else 8)
+        if mode == 'train' and (self.use_iou_loss or self.use_iou_aware):
+            model.max_height = self.max_height
+            model.max_width = self.max_width
+        inputs = model.generate_inputs()
+        model_out = model.build_net(inputs)
+        outputs = OrderedDict([('bbox', model_out)])
+        if mode == 'train':
+            self.optimizer.minimize(model_out)
+            outputs = OrderedDict([('loss', model_out)])
+            if self.use_ema:
+                global_steps = _decay_step_counter()
+                self.ema = ExponentialMovingAverage(
+                    self.ema_decay, thres_steps=global_steps)
+                self.ema.update()
+        return inputs, outputs
+
+    def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
+                          lr_decay_epochs, lr_decay_gamma,
+                          num_steps_each_epoch):
+        if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+            logging.error(
+                "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
+                exit=False)
+            logging.error(
+                "See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
+                exit=False)
+            logging.error(
+                "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
+                format(lr_decay_epochs[0] * num_steps_each_epoch, warmup_steps
+                       // num_steps_each_epoch))
+        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
+        values = [(lr_decay_gamma**i) * learning_rate
+                  for i in range(len(lr_decay_epochs) + 1)]
+        lr_decay = fluid.layers.piecewise_decay(
+            boundaries=boundaries, values=values)
+        lr_warmup = fluid.layers.linear_lr_warmup(
+            learning_rate=lr_decay,
+            warmup_steps=warmup_steps,
+            start_lr=warmup_start_lr,
+            end_lr=learning_rate)
+        optimizer = fluid.optimizer.Momentum(
+            learning_rate=lr_warmup,
+            momentum=0.9,
+            regularization=fluid.regularizer.L2DecayRegularizer(5e-04))
+        return optimizer
+
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=8,
+              eval_dataset=None,
+              save_interval_epochs=20,
+              log_interval_steps=2,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              optimizer=None,
+              learning_rate=1.0 / 8000,
+              warmup_steps=1000,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=[213, 240],
+              lr_decay_gamma=0.1,
+              metric=None,
+              use_vdl=False,
+              sensitivities_file=None,
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5,
+              resume_checkpoint=None,
+              use_ema=True,
+              ema_decay=0.9998):
+        """训练。
+
+        Args:
+            num_epochs (int): 训练迭代轮数。
+            train_dataset (paddlex.datasets): 训练数据读取器。
+            train_batch_size (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡
+                数量之商为验证数据batch大小。默认值为8。
+            eval_dataset (paddlex.datasets): 验证数据读取器。
+            save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为20。
+            log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为10。
+            save_dir (str): 模型保存路径。默认值为'output'。
+            pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
+                则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',
+                则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
+            optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
+                fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
+            learning_rate (float): 默认优化器的学习率。默认为1.0/8000。
+            warmup_steps (int):  默认优化器进行warmup过程的步数。默认为1000。
+            warmup_start_lr (int): 默认优化器warmup的起始学习率。默认为0.0。
+            lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
+            lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
+            metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
+            use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
+            sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
+                则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
+            eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
+
+        Raises:
+            ValueError: 评估类型不在指定列表中。
+            ValueError: 模型从inference model进行加载。
+        """
+        if not self.trainable:
+            raise ValueError("Model is not trainable from load_model method.")
+        if metric is None:
+            if isinstance(train_dataset, paddlex.datasets.CocoDetection):
+                metric = 'COCO'
+            elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
+                    isinstance(train_dataset, paddlex.datasets.EasyDataDet):
+                metric = 'VOC'
+            else:
+                raise ValueError(
+                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
+                )
+        assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
+        self.metric = metric
+
+        self.labels = train_dataset.labels
+        # Build the training network
+        if optimizer is None:
+            # Build the default optimization strategy
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            optimizer = self.default_optimizer(
+                learning_rate=learning_rate,
+                warmup_steps=warmup_steps,
+                warmup_start_lr=warmup_start_lr,
+                lr_decay_epochs=lr_decay_epochs,
+                lr_decay_gamma=lr_decay_gamma,
+                num_steps_each_epoch=num_steps_each_epoch)
+        self.optimizer = optimizer
+        self.use_ema = use_ema
+        self.ema_decay = ema_decay
+
+        self.batch_size_per_gpu = int(train_batch_size /
+                                      paddlex.env_info['num'])
+        if self.use_fine_grained_loss:
+            for transform in train_dataset.transforms.transforms:
+                if isinstance(transform, paddlex.det.transforms.Resize):
+                    self.max_height = transform.target_size
+                    self.max_width = transform.target_size
+                    break
+        if train_dataset.transforms.batch_transforms is None:
+            train_dataset.transforms.batch_transforms = list()
+        define_random_shape = False
+        for bt in train_dataset.transforms.batch_transforms:
+            if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
+                define_random_shape = True
+        if not define_random_shape:
+            if isinstance(self.train_random_shapes,
+                          (list, tuple)) and len(self.train_random_shapes) > 0:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.BatchRandomShape(
+                        random_shapes=self.train_random_shapes))
+                if self.use_fine_grained_loss:
+                    self.max_height = max(self.max_height,
+                                          max(self.train_random_shapes))
+                    self.max_width = max(self.max_width,
+                                         max(self.train_random_shapes))
+        if self.use_fine_grained_loss:
+            define_generate_target = False
+            for bt in train_dataset.transforms.batch_transforms:
+                if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
+                    define_generate_target = True
+            if not define_generate_target:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.GenerateYoloTarget(
+                        anchors=self.anchors,
+                        anchor_masks=self.anchor_masks,
+                        num_classes=self.num_classes,
+                        downsample_ratios=[32, 16, 8]))
+        # Build the training, evaluation and prediction networks
+        self.build_program()
+        # Initialize the network weights
+        self.net_initialize(
+            startup_prog=fluid.default_startup_program(),
+            pretrain_weights=pretrain_weights,
+            save_dir=save_dir,
+            sensitivities_file=sensitivities_file,
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
+        # Start training
+        self.train_loop(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
+
+    def evaluate(self,
+                 eval_dataset,
+                 batch_size=1,
+                 epoch_id=None,
+                 metric=None,
+                 return_details=False):
+        """评估。
+
+        Args:
+            eval_dataset (paddlex.datasets): 验证数据读取器。
+            batch_size (int): 验证数据批大小。默认为1。
+            epoch_id (int): 当前评估模型所在的训练轮数。
+            metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
+                根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
+                如为COCODetection,则metric为'COCO'。
+            return_details (bool): 是否返回详细信息。
+
+        Returns:
+            tuple (metrics, eval_details) | dict (metrics): 当return_details为True时,返回(metrics, eval_details),
+                当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
+                分别表示平均准确率平均值在各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
+                eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
+                预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
+        """
+        arrange_transforms(
+            model_type=self.model_type,
+            class_name=self.__class__.__name__,
+            transforms=eval_dataset.transforms,
+            mode='eval')
+        if metric is None:
+            if hasattr(self, 'metric') and self.metric is not None:
+                metric = self.metric
+            else:
+                if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
+                    metric = 'COCO'
+                elif isinstance(eval_dataset, paddlex.datasets.VOCDetection):
+                    metric = 'VOC'
+                else:
+                    raise Exception(
+                        "eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
+                    )
+        assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
+
+        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
+        results = list()
+
+        data_generator = eval_dataset.generator(
+            batch_size=batch_size, drop_last=False)
+        logging.info(
+            "Start to evaluating(total_samples={}, total_steps={})...".format(
+                eval_dataset.num_samples, total_steps))
+        for step, data in tqdm.tqdm(
+                enumerate(data_generator()), total=total_steps):
+            images = np.array([d[0] for d in data])
+            im_sizes = np.array([d[1] for d in data])
+            feed_data = {'image': images, 'im_size': im_sizes}
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.test_prog,
+                    feed=[feed_data],
+                    fetch_list=list(self.test_outputs.values()),
+                    return_numpy=False)
+            res = {
+                'bbox': (np.array(outputs[0]),
+                         outputs[0].recursive_sequence_lengths())
+            }
+            res_id = [np.array([d[2]]) for d in data]
+            res['im_id'] = (res_id, [])
+            if metric == 'VOC':
+                res_gt_box = [d[3].reshape(-1, 4) for d in data]
+                res_gt_label = [d[4].reshape(-1, 1) for d in data]
+                res_is_difficult = [d[5].reshape(-1, 1) for d in data]
+                res_id = [np.array([d[2]]) for d in data]
+                res['gt_box'] = (res_gt_box, [])
+                res['gt_label'] = (res_gt_label, [])
+                res['is_difficult'] = (res_is_difficult, [])
+            results.append(res)
+            logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
+                                                               1, total_steps))
+        box_ap_stats, eval_details = eval_results(
+            results, metric, eval_dataset.coco_gt, with_background=False)
+        evaluate_metrics = OrderedDict(
+            zip(['bbox_mmap'
+                 if metric == 'COCO' else 'bbox_map'], box_ap_stats))
+        if return_details:
+            return evaluate_metrics, eval_details
+        return evaluate_metrics
+
+    @staticmethod
+    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
+        arrange_transforms(
+            model_type=model_type,
+            class_name=class_name,
+            transforms=transforms,
+            mode='test')
+        pool = ThreadPool(thread_num)
+        batch_data = pool.map(transforms, images)
+        pool.close()
+        pool.join()
+        padding_batch = generate_minibatch(batch_data)
+        im = np.array(
+            [data[0] for data in padding_batch],
+            dtype=padding_batch[0][0].dtype)
+        im_size = np.array([data[1] for data in padding_batch], dtype=np.int32)
+
+        return im, im_size
+
+    @staticmethod
+    def _postprocess(res, batch_size, num_classes, labels):
+        clsid2catid = dict({i: i for i in range(num_classes)})
+        xywh_results = bbox2out([res], clsid2catid)
+        preds = [[] for i in range(batch_size)]
+        for xywh_res in xywh_results:
+            image_id = xywh_res['image_id']
+            del xywh_res['image_id']
+            xywh_res['category'] = labels[xywh_res['category_id']]
+            preds[image_id].append(xywh_res)
+
+        return preds
+
+    def predict(self, img_file, transforms=None):
+        """预测。
+
+        Args:
+            img_file (str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
+            transforms (paddlex.det.transforms): 数据预处理操作。
+
+        Returns:
+            list: 预测结果列表,每个预测结果由预测框类别标签、
+              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
+              预测框得分组成。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            raise Exception("img_file must be str/np.ndarray")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_size = PPYOLO._preprocess(images, transforms, self.model_type,
+                                         self.__class__.__name__)
+
+        with fluid.scope_guard(self.scope):
+            result = self.exe.run(self.test_prog,
+                                  feed={'image': im,
+                                        'im_size': im_size},
+                                  fetch_list=list(self.test_outputs.values()),
+                                  return_numpy=False,
+                                  use_program_cache=True)
+
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(self.test_outputs.keys()), result)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(len(images))]).astype('int32'), [[]])
+        preds = PPYOLO._postprocess(res,
+                                    len(images), self.num_classes, self.labels)
+        return preds[0]
+
+    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
+        """预测。
+
+        Args:
+            img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
+                且类型为float32且为BGR格式的数组。
+            transforms (paddlex.det.transforms): 数据预处理操作。
+            thread_num (int): 并发执行各图像预处理时的线程数。
+        Returns:
+            list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
+              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
+              预测框得分组成。
+        """
+        if transforms is None and not hasattr(self, 'test_transforms'):
+            raise Exception("transforms need to be defined, now is None.")
+
+        if not isinstance(img_file_list, (list, tuple)):
+            raise Exception("im_file must be list/tuple")
+
+        if transforms is None:
+            transforms = self.test_transforms
+        im, im_size = PPYOLO._preprocess(img_file_list, transforms,
+                                         self.model_type,
+                                         self.__class__.__name__, thread_num)
+
+        with fluid.scope_guard(self.scope):
+            result = self.exe.run(self.test_prog,
+                                  feed={'image': im,
+                                        'im_size': im_size},
+                                  fetch_list=list(self.test_outputs.values()),
+                                  return_numpy=False,
+                                  use_program_cache=True)
+
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(self.test_outputs.keys()), result)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(len(img_file_list))]).astype('int32'), [[]])
+        preds = PPYOLO._postprocess(res,
+                                    len(img_file_list), self.num_classes,
+                                    self.labels)
+        return preds

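A hedged inference sketch against the predict() API above ('output/ppyolo/best_model' and 'test.jpg' are placeholder paths):

# Hedged sketch: load a trained PPYOLO snapshot and run prediction.
import paddlex as pdx

model = pdx.load_model('output/ppyolo/best_model')
result = model.predict('test.jpg')
for pred in result:
    # each prediction carries category_id, category, bbox ([xmin, ymin, w, h]) and score
    print(pred['category'], pred['bbox'], pred['score'])
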
+ 20 - 321
paddlex/cv/models/yolo_v3.py

@@ -15,21 +15,11 @@
 from __future__ import absolute_import
 import math
 import tqdm
-import os.path as osp
-import numpy as np
-from multiprocessing.pool import ThreadPool
-import paddle.fluid as fluid
-import paddlex.utils.logging as logging
 import paddlex
-import copy
-from paddlex.cv.transforms import arrange_transforms
-from paddlex.cv.datasets import generate_minibatch
-from .base import BaseAPI
-from collections import OrderedDict
-from .utils.detection_eval import eval_results, bbox2out
+from .ppyolo import PPYOLO
 
 
-class YOLOv3(BaseAPI):
+class YOLOv3(PPYOLO):
     """构建YOLOv3,并实现其训练、评估、预测和模型导出。
 
     Args:
@@ -65,12 +55,12 @@ class YOLOv3(BaseAPI):
                      320, 352, 384, 416, 448, 480, 512, 544, 576, 608
                  ]):
         self.init_params = locals()
-        super(YOLOv3, self).__init__('detector')
         backbones = [
             'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
+        super(YOLOv3, self).__init__('detector')
         self.backbone = backbone
         self.num_classes = num_classes
         self.anchors = anchors
@@ -84,6 +74,16 @@ class YOLOv3(BaseAPI):
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = None
+        self.use_fine_grained_loss = False
+        self.use_coord_conv = False
+        self.use_iou_aware = False
+        self.use_spp = False
+        self.use_drop_block = False
+        self.use_iou_loss = False
+        self.scale_x_y = 1.
+        self.use_matrix_nms = False
+        self.use_ema = False
+        self.with_dcn_v2 = False
 
     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -104,59 +104,6 @@ class YOLOv3(BaseAPI):
                 norm_type='sync_bn', model_name=model_name)
         return backbone
 
-    def build_net(self, mode='train'):
-        model = paddlex.cv.nets.detection.YOLOv3(
-            backbone=self._get_backbone(self.backbone),
-            num_classes=self.num_classes,
-            mode=mode,
-            anchors=self.anchors,
-            anchor_masks=self.anchor_masks,
-            ignore_threshold=self.ignore_threshold,
-            label_smooth=self.label_smooth,
-            nms_score_threshold=self.nms_score_threshold,
-            nms_topk=self.nms_topk,
-            nms_keep_topk=self.nms_keep_topk,
-            nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes,
-            fixed_input_shape=self.fixed_input_shape)
-        inputs = model.generate_inputs()
-        model_out = model.build_net(inputs)
-        outputs = OrderedDict([('bbox', model_out)])
-        if mode == 'train':
-            self.optimizer.minimize(model_out)
-            outputs = OrderedDict([('loss', model_out)])
-        return inputs, outputs
-
-    def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
-                          lr_decay_epochs, lr_decay_gamma,
-                          num_steps_each_epoch):
-        if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
-            logging.error(
-                "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
-                exit=False)
-            logging.error(
-                "See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
-                exit=False)
-            logging.error(
-                "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
-                format(lr_decay_epochs[0] * num_steps_each_epoch, warmup_steps
-                       // num_steps_each_epoch))
-        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
-        values = [(lr_decay_gamma**i) * learning_rate
-                  for i in range(len(lr_decay_epochs) + 1)]
-        lr_decay = fluid.layers.piecewise_decay(
-            boundaries=boundaries, values=values)
-        lr_warmup = fluid.layers.linear_lr_warmup(
-            learning_rate=lr_decay,
-            warmup_steps=warmup_steps,
-            start_lr=warmup_start_lr,
-            end_lr=learning_rate)
-        optimizer = fluid.optimizer.Momentum(
-            learning_rate=lr_warmup,
-            momentum=0.9,
-            regularization=fluid.regularizer.L2DecayRegularizer(5e-04))
-        return optimizer
-
     def train(self,
               num_epochs,
               train_dataset,
@@ -214,259 +161,11 @@ class YOLOv3(BaseAPI):
             ValueError: The evaluation metric is not in the allowed list.
             ValueError: The model was loaded from an inference model.
         """
-        if not self.trainable:
-            raise ValueError("Model is not trainable from load_model method.")
-        if metric is None:
-            if isinstance(train_dataset, paddlex.datasets.CocoDetection):
-                metric = 'COCO'
-            elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
-                    isinstance(train_dataset, paddlex.datasets.EasyDataDet):
-                metric = 'VOC'
-            else:
-                raise ValueError(
-                    "train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
-                )
-        assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
-        self.metric = metric
-
-        self.labels = train_dataset.labels
-        # Build the training network
-        if optimizer is None:
-            # Build the default optimization strategy
-            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
-            optimizer = self.default_optimizer(
-                learning_rate=learning_rate,
-                warmup_steps=warmup_steps,
-                warmup_start_lr=warmup_start_lr,
-                lr_decay_epochs=lr_decay_epochs,
-                lr_decay_gamma=lr_decay_gamma,
-                num_steps_each_epoch=num_steps_each_epoch)
-        self.optimizer = optimizer
-        # Build the training, evaluation and prediction networks
-        self.build_program()
-        # Initialize the network weights
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss,
-            resume_checkpoint=resume_checkpoint)
-        # Start training
-        self.train_loop(
-            num_epochs=num_epochs,
-            train_dataset=train_dataset,
-            train_batch_size=train_batch_size,
-            eval_dataset=eval_dataset,
-            save_interval_epochs=save_interval_epochs,
-            log_interval_steps=log_interval_steps,
-            save_dir=save_dir,
-            use_vdl=use_vdl,
-            early_stop=early_stop,
-            early_stop_patience=early_stop_patience)
-
-    def evaluate(self,
-                 eval_dataset,
-                 batch_size=1,
-                 epoch_id=None,
-                 metric=None,
-                 return_details=False):
-        """评估。
-
-        Args:
-            eval_dataset (paddlex.datasets): 验证数据读取器。
-            batch_size (int): 验证数据批大小。默认为1。
-            epoch_id (int): 当前评估模型所在的训练轮数。
-            metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
-                根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
-                如为COCODetection,则metric为'COCO'。
-            return_details (bool): 是否返回详细信息。
-
-        Returns:
-            tuple (metrics, eval_details) | dict (metrics): 当return_details为True时,返回(metrics, eval_details),
-                当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
-                分别表示平均准确率平均值在各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
-                eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
-                预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
-        """
-        arrange_transforms(
-            model_type=self.model_type,
-            class_name=self.__class__.__name__,
-            transforms=eval_dataset.transforms,
-            mode='eval')
-        if metric is None:
-            if hasattr(self, 'metric') and self.metric is not None:
-                metric = self.metric
-            else:
-                if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
-                    metric = 'COCO'
-                elif isinstance(eval_dataset, paddlex.datasets.VOCDetection):
-                    metric = 'VOC'
-                else:
-                    raise Exception(
-                        "eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
-                    )
-        assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
-
-        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
-        results = list()
-
-        data_generator = eval_dataset.generator(
-            batch_size=batch_size, drop_last=False)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
-        for step, data in tqdm.tqdm(
-                enumerate(data_generator()), total=total_steps):
-            images = np.array([d[0] for d in data])
-            im_sizes = np.array([d[1] for d in data])
-            feed_data = {'image': images, 'im_size': im_sizes}
-            with fluid.scope_guard(self.scope):
-                outputs = self.exe.run(
-                    self.test_prog,
-                    feed=[feed_data],
-                    fetch_list=list(self.test_outputs.values()),
-                    return_numpy=False)
-            res = {
-                'bbox': (np.array(outputs[0]),
-                         outputs[0].recursive_sequence_lengths())
-            }
-            res_id = [np.array([d[2]]) for d in data]
-            res['im_id'] = (res_id, [])
-            if metric == 'VOC':
-                res_gt_box = [d[3].reshape(-1, 4) for d in data]
-                res_gt_label = [d[4].reshape(-1, 1) for d in data]
-                res_is_difficult = [d[5].reshape(-1, 1) for d in data]
-                res_id = [np.array([d[2]]) for d in data]
-                res['gt_box'] = (res_gt_box, [])
-                res['gt_label'] = (res_gt_label, [])
-                res['is_difficult'] = (res_is_difficult, [])
-            results.append(res)
-            logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
-                                                               1, total_steps))
-        box_ap_stats, eval_details = eval_results(
-            results, metric, eval_dataset.coco_gt, with_background=False)
-        evaluate_metrics = OrderedDict(
-            zip(['bbox_mmap'
-                 if metric == 'COCO' else 'bbox_map'], box_ap_stats))
-        if return_details:
-            return evaluate_metrics, eval_details
-        return evaluate_metrics
-
-    @staticmethod
-    def _preprocess(images, transforms, model_type, class_name, thread_num=1):
-        arrange_transforms(
-            model_type=model_type,
-            class_name=class_name,
-            transforms=transforms,
-            mode='test')
-        pool = ThreadPool(thread_num)
-        batch_data = pool.map(transforms, images)
-        pool.close()
-        pool.join()
-        padding_batch = generate_minibatch(batch_data)
-        im = np.array(
-            [data[0] for data in padding_batch],
-            dtype=padding_batch[0][0].dtype)
-        im_size = np.array([data[1] for data in padding_batch], dtype=np.int32)
-
-        return im, im_size
-
-    @staticmethod
-    def _postprocess(res, batch_size, num_classes, labels):
-        clsid2catid = dict({i: i for i in range(num_classes)})
-        xywh_results = bbox2out([res], clsid2catid)
-        preds = [[] for i in range(batch_size)]
-        for xywh_res in xywh_results:
-            image_id = xywh_res['image_id']
-            del xywh_res['image_id']
-            xywh_res['category'] = labels[xywh_res['category_id']]
-            preds[image_id].append(xywh_res)
-
-        return preds
-
-    def predict(self, img_file, transforms=None):
-        """预测。
-
-        Args:
-            img_file (str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
-            transforms (paddlex.det.transforms): 数据预处理操作。
-
-        Returns:
-            list: 预测结果列表,每个预测结果由预测框类别标签、
-              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
-              预测框得分组成。
-        """
-        if transforms is None and not hasattr(self, 'test_transforms'):
-            raise Exception("transforms need to be defined, now is None.")
-        if isinstance(img_file, (str, np.ndarray)):
-            images = [img_file]
-        else:
-            raise Exception("img_file must be str/np.ndarray")
-
-        if transforms is None:
-            transforms = self.test_transforms
-        im, im_size = YOLOv3._preprocess(images, transforms, self.model_type,
-                                         self.__class__.__name__)
-
-        with fluid.scope_guard(self.scope):
-            result = self.exe.run(self.test_prog,
-                                  feed={'image': im,
-                                        'im_size': im_size},
-                                  fetch_list=list(self.test_outputs.values()),
-                                  return_numpy=False,
-                                  use_program_cache=True)
-
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(self.test_outputs.keys()), result)
-        }
-        res['im_id'] = (np.array(
-            [[i] for i in range(len(images))]).astype('int32'), [[]])
-        preds = YOLOv3._postprocess(res,
-                                    len(images), self.num_classes, self.labels)
-        return preds[0]
-
-    def batch_predict(self, img_file_list, transforms=None, thread_num=2):
-        """预测。
-
-        Args:
-            img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
-                且类型为float32且为BGR格式的数组。
-            transforms (paddlex.det.transforms): 数据预处理操作。
-            thread_num (int): 并发执行各图像预处理时的线程数。
-        Returns:
-            list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
-              预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
-              预测框得分组成。
-        """
-        if transforms is None and not hasattr(self, 'test_transforms'):
-            raise Exception("transforms need to be defined, now is None.")
-
-        if not isinstance(img_file_list, (list, tuple)):
-            raise Exception("im_file must be list/tuple")
-
-        if transforms is None:
-            transforms = self.test_transforms
-        im, im_size = YOLOv3._preprocess(img_file_list, transforms,
-                                         self.model_type,
-                                         self.__class__.__name__, thread_num)
-
-        with fluid.scope_guard(self.scope):
-            result = self.exe.run(self.test_prog,
-                                  feed={'image': im,
-                                        'im_size': im_size},
-                                  fetch_list=list(self.test_outputs.values()),
-                                  return_numpy=False,
-                                  use_program_cache=True)
 
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(self.test_outputs.keys()), result)
-        }
-        res['im_id'] = (np.array(
-            [[i] for i in range(len(img_file_list))]).astype('int32'), [[]])
-        preds = YOLOv3._postprocess(res,
-                                    len(img_file_list), self.num_classes,
-                                    self.labels)
-        return preds
+        return super(YOLOv3, self).train(
+            num_epochs, train_dataset, train_batch_size, eval_dataset,
+            save_interval_epochs, log_interval_steps, save_dir,
+            pretrain_weights, optimizer, learning_rate, warmup_steps,
+            warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric, use_vdl,
+            sensitivities_file, eval_metric_loss, early_stop,
+            early_stop_patience, resume_checkpoint, False)

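After this refactor YOLOv3 is a thin subclass of PPYOLO: __init__ switches every PPYOLO-specific enhancement off, and train() forwards to the parent with the trailing use_ema argument pinned to False, so existing user code keeps working unchanged. A hedged sketch (constructor only; training calls look exactly as before):

# Hedged sketch: the public YOLOv3 API is untouched by the refactor.
import paddlex as pdx

# Internally this now routes through PPYOLO with coord-conv, IoU-aware,
# SPP, DropBlock, matrix NMS and EMA all disabled.
model = pdx.det.YOLOv3(num_classes=20, backbone='DarkNet53')
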
+ 85 - 0
paddlex/cv/nets/detection/iou_aware.py

@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+
+
+def _split_ioup(output, an_num, num_classes):
+    """
+    Split the head's output feature map into the predicted iou and the
+    original output along the channel dimension
+    """
+    ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+    ioup = fluid.layers.sigmoid(ioup)
+
+    oriout = fluid.layers.slice(
+        output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
+
+    return (ioup, oriout)
+
+
+def _de_sigmoid(x, eps=1e-7):
+    x = fluid.layers.clip(x, eps, 1 / eps)
+    one = fluid.layers.fill_constant(
+        shape=[1, 1, 1, 1], dtype=x.dtype, value=1.)
+    x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps)
+    x = -fluid.layers.log(x)
+    return x
+
+
+def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
+    """
+    Post-process the output objectness score using the predicted iou
+    """
+    tensors = []
+    stride = output.shape[1] // an_num
+    for m in range(an_num):
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 0],
+                ends=[stride * m + 4]))
+        obj = fluid.layers.slice(
+            output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
+        obj = fluid.layers.sigmoid(obj)
+        ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
+
+        new_obj = fluid.layers.pow(obj, (
+            1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor)
+        new_obj = _de_sigmoid(new_obj)
+
+        tensors.append(new_obj)
+
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 5],
+                ends=[stride * m + 5 + num_classes]))
+
+    output = fluid.layers.concat(tensors, axis=1)
+
+    return output
+
+
+def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
+    ioup, output = _split_ioup(output, an_num, num_classes)
+    output = _postprocess_output(ioup, output, an_num, num_classes,
+                                 iou_aware_factor)
+    return output

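What get_iou_aware_score computes, read in plain numpy: _de_sigmoid is a clipped inverse sigmoid (logit), and the new objectness is the geometric blend obj**(1 - factor) * iou**factor mapped back through it. A hedged check of that reading:

# Hedged numpy mirror of _de_sigmoid and the objectness blend above.
import numpy as np

def de_sigmoid(x, eps=1e-7):
    # -log(1/x - 1) == logit(x); the clips bound it away from 0 and 1
    x = np.clip(x, eps, 1.0 / eps)
    return -np.log(np.clip(1.0 / x - 1.0, eps, 1.0 / eps))

p = np.array([0.1, 0.5, 0.9])
print(np.allclose(1.0 / (1.0 + np.exp(-de_sigmoid(p))), p))  # True

obj, pred_iou, factor = 0.8, 0.6, 0.4   # factor == iou_aware_factor
new_obj = obj ** (1 - factor) * pred_iou ** factor
print(de_sigmoid(np.array([new_obj])))  # raw score written back into the map
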
+ 77 - 0
paddlex/cv/nets/detection/loss/iou_aware_loss.py

@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import NumpyArrayInitializer
+
+from paddle import fluid
+from .iou_loss import IouLoss
+
+
+class IouAwareLoss(IouLoss):
+    """
+    iou aware loss, see https://arxiv.org/abs/1912.05992
+    Args:
+        loss_weight (float): iou aware loss weight, default is 1.0
+        max_height (int): max height of input to support random shape input
+        max_width (int): max width of input to support random shape input
+    """
+
+    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
+        super(IouAwareLoss, self).__init__(
+            loss_weight=loss_weight,
+            max_height=max_height,
+            max_width=max_width)
+
+    def __call__(self,
+                 ioup,
+                 x,
+                 y,
+                 w,
+                 h,
+                 tx,
+                 ty,
+                 tw,
+                 th,
+                 anchors,
+                 downsample_ratio,
+                 batch_size,
+                 scale_x_y,
+                 eps=1.e-10):
+        '''
+        Args:
+            ioup ([Variables]): the predicted iou
+            x  | y | w | h  ([Variables]): the output of yolov3 for encoded x|y|w|h
+            tx |ty |tw |th  ([Variables]): the target of yolov3 for encoded x|y|w|h
+            anchors ([float]): list of anchors for current output layer
+            downsample_ratio (float): the downsample ratio for current output layer
+            batch_size (int): training batch size
+            eps (float): a small constant that prevents the denominator from being zero
+        '''
+
+        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
+                                    batch_size, False, scale_x_y, eps)
+        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
+                                  batch_size, True, scale_x_y, eps)
+        iouk = self._iou(pred, gt, ioup, eps)
+        iouk.stop_gradient = True
+
+        loss_iou_aware = fluid.layers.cross_entropy(
+            ioup, iouk, soft_label=True)
+        loss_iou_aware = loss_iou_aware * self._loss_weight
+        return loss_iou_aware

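A reduced numpy reading of the IoU-aware loss above (hedged: fluid.layers.cross_entropy with soft_label=True computes -sum(label * log(pred)) over the channel axis, so with the single iou channel used here it degenerates to the expression below):

# Hedged numpy sketch of IouAwareLoss: the gradient-stopped box iou acts
# as a soft label for the sigmoid-activated iou prediction.
import numpy as np

ioup = np.array([0.7, 0.4])   # predicted iou, after sigmoid
iouk = np.array([0.9, 0.3])   # measured box iou (stop_gradient target)
loss_weight = 1.0

loss_iou_aware = -iouk * np.log(ioup) * loss_weight
print(loss_iou_aware)
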
+ 235 - 0
paddlex/cv/nets/detection/loss/iou_loss.py

@@ -0,0 +1,235 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import NumpyArrayInitializer
+
+from paddle import fluid
+
+
+class IouLoss(object):
+    """
+    iou loss, see https://arxiv.org/abs/1908.03851
+    loss = 1.0 - iou * iou
+    Args:
+        loss_weight (float): iou loss weight, default is 2.5
+        max_height (int): max height of input to support random shape input
+        max_width (int): max width of input to support random shape input
+        ciou_term (bool): whether to add ciou_term
+        loss_square (bool): whether to square the iou term
+    """
+
+    def __init__(self,
+                 loss_weight=2.5,
+                 max_height=608,
+                 max_width=608,
+                 ciou_term=False,
+                 loss_square=True):
+        self._loss_weight = loss_weight
+        self._MAX_HI = max_height
+        self._MAX_WI = max_width
+        self.ciou_term = ciou_term
+        self.loss_square = loss_square
+
+    def __call__(self,
+                 x,
+                 y,
+                 w,
+                 h,
+                 tx,
+                 ty,
+                 tw,
+                 th,
+                 anchors,
+                 downsample_ratio,
+                 batch_size,
+                 scale_x_y=1.,
+                 ioup=None,
+                 eps=1.e-10):
+        '''
+        Args:
+            x  | y | w | h  ([Variables]): the output of yolov3 for encoded x|y|w|h
+            tx |ty |tw |th  ([Variables]): the target of yolov3 for encoded x|y|w|h
+            anchors ([float]): list of anchors for the current output layer
+            downsample_ratio (float): the downsample ratio of the current output layer
+            batch_size (int): training batch size
+            scale_x_y (float): scale factor applied to the decoded box center
+            ioup (Variable): the predicted iou, used by the IoU-aware variant
+            eps (float): small value that keeps the denominator away from zero
+        '''
+        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
+                                    batch_size, False, scale_x_y, eps)
+        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
+                                  batch_size, True, scale_x_y, eps)
+        iouk = self._iou(pred, gt, ioup, eps)
+        if self.loss_square:
+            loss_iou = 1. - iouk * iouk
+        else:
+            loss_iou = 1. - iouk
+        loss_iou = loss_iou * self._loss_weight
+
+        return loss_iou
+
+    def _iou(self, pred, gt, ioup=None, eps=1.e-10):
+        x1, y1, x2, y2 = pred
+        x1g, y1g, x2g, y2g = gt
+        x2 = fluid.layers.elementwise_max(x1, x2)
+        y2 = fluid.layers.elementwise_max(y1, y2)
+
+        xkis1 = fluid.layers.elementwise_max(x1, x1g)
+        ykis1 = fluid.layers.elementwise_max(y1, y1g)
+        xkis2 = fluid.layers.elementwise_min(x2, x2g)
+        ykis2 = fluid.layers.elementwise_min(y2, y2g)
+
+        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
+        intsctk = intsctk * fluid.layers.greater_than(
+            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
+        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
+                                                        ) - intsctk + eps
+        iouk = intsctk / unionk
+        if self.ciou_term:
+            ciou = self.get_ciou_term(pred, gt, iouk, eps)
+            iouk = iouk - ciou
+        return iouk
+
+    def get_ciou_term(self, pred, gt, iouk, eps):
+        x1, y1, x2, y2 = pred
+        x1g, y1g, x2g, y2g = gt
+
+        cx = (x1 + x2) / 2
+        cy = (y1 + y2) / 2
+        w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
+        h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')
+
+        cxg = (x1g + x2g) / 2
+        cyg = (y1g + y2g) / 2
+        wg = x2g - x1g
+        hg = y2g - y1g
+
+        # A or B
+        xc1 = fluid.layers.elementwise_min(x1, x1g)
+        yc1 = fluid.layers.elementwise_min(y1, y1g)
+        xc2 = fluid.layers.elementwise_max(x2, x2g)
+        yc2 = fluid.layers.elementwise_max(y2, y2g)
+
+        # DIOU term
+        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
+        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
+        diou_term = (dist_intersection + eps) / (dist_union + eps)
+        # CIOU term
+        ciou_term = 0
+        ar_gt = wg / hg
+        ar_pred = w / h
+        arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
+        ar_loss = 4. / np.pi / np.pi * arctan * arctan
+        alpha = ar_loss / (1 - iouk + ar_loss + eps)
+        alpha.stop_gradient = True
+        ciou_term = alpha * ar_loss
+        return diou_term + ciou_term
+
+    def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
+                        batch_size, is_gt, scale_x_y, eps):
+        grid_x = int(self._MAX_WI / downsample_ratio)
+        grid_y = int(self._MAX_HI / downsample_ratio)
+        an_num = len(anchors) // 2
+
+        shape_fmp = fluid.layers.shape(dcx)
+        shape_fmp.stop_gradient = True
+        # generate the grid_x x grid_y cell indices of the feature map
+        idx_i = np.array([[i for i in range(grid_x)]])
+        idx_j = np.array([[j for j in range(grid_y)]]).transpose()
+        gi_np = np.repeat(idx_i, grid_y, axis=0)
+        gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
+        gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
+        gj_np = np.repeat(idx_j, grid_x, axis=1)
+        gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
+        gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
+        gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
+        gi = fluid.layers.crop(x=gi_max, shape=dcx)
+        gi.stop_gradient = True
+        gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
+        gj = fluid.layers.crop(x=gj_max, shape=dcx)
+        gj.stop_gradient = True
+
+        grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
+        grid_x_act.stop_gradient = True
+        grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
+        grid_y_act.stop_gradient = True
+        if is_gt:
+            cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
+            cx.stop_gradient = True  # ground-truth coordinates carry no gradient
+            cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
+            cy.stop_gradient = True
+        else:
+            dcx_sig = fluid.layers.sigmoid(dcx)
+            dcy_sig = fluid.layers.sigmoid(dcy)
+            if (abs(scale_x_y - 1.0) > eps):
+                dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
+                dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
+            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
+            cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act
+
+        anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
+        anchor_w_np = np.array(anchor_w_)
+        anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
+        anchor_w_np = np.tile(
+            anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
+        anchor_w_max = self._create_tensor_from_numpy(
+            anchor_w_np.astype(np.float32))
+        anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
+        anchor_w.stop_gradient = True
+        anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1]
+        anchor_h_np = np.array(anchor_h_)
+        anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
+        anchor_h_np = np.tile(
+            anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
+        anchor_h_max = self._create_tensor_from_numpy(
+            anchor_h_np.astype(np.float32))
+        anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
+        anchor_h.stop_gradient = True
+        # e^tw e^th
+        exp_dw = fluid.layers.exp(dw)
+        exp_dh = fluid.layers.exp(dh)
+        pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
+            (grid_x_act * downsample_ratio)
+        ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
+            (grid_y_act * downsample_ratio)
+        if is_gt:
+            exp_dw.stop_gradient = True
+            exp_dh.stop_gradient = True
+            pw.stop_gradient = True
+            ph.stop_gradient = True
+
+        x1 = cx - 0.5 * pw
+        y1 = cy - 0.5 * ph
+        x2 = cx + 0.5 * pw
+        y2 = cy + 0.5 * ph
+        if is_gt:
+            x1.stop_gradient = True
+            y1.stop_gradient = True
+            x2.stop_gradient = True
+            y2.stop_gradient = True
+
+        return x1, y1, x2, y2
+
+    def _create_tensor_from_numpy(self, numpy_array):
+        paddle_array = fluid.layers.create_parameter(
+            attr=ParamAttr(),
+            shape=numpy_array.shape,
+            dtype=numpy_array.dtype,
+            default_initializer=NumpyArrayInitializer(numpy_array))
+        paddle_array.stop_gradient = True
+        return paddle_array

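For reference, `IouLoss._iou` clamps the predicted corners so that x2 >= x1 and y2 >= y1, zeroes the intersection when the boxes do not overlap, and pads the union with `eps` before dividing. A self-contained NumPy sketch of the same computation, with hypothetical boxes in (x1, y1, x2, y2) form:

import numpy as np

def iou_sketch(pred, gt, eps=1e-10):
    """IoU between (x1, y1, x2, y2) boxes, mirroring IouLoss._iou."""
    x1, y1, x2, y2 = pred
    x1g, y1g, x2g, y2g = gt
    # clamp so the predicted box has non-negative width and height
    x2, y2 = np.maximum(x1, x2), np.maximum(y1, y2)
    # intersection corners
    xi1, yi1 = np.maximum(x1, x1g), np.maximum(y1, y1g)
    xi2, yi2 = np.minimum(x2, x2g), np.minimum(y2, y2g)
    # zero the intersection when the boxes do not overlap
    inter = (xi2 - xi1) * (yi2 - yi1) * (xi2 > xi1) * (yi2 > yi1)
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - inter + eps
    return inter / union

# two half-overlapping unit squares -> IoU = 1/3
print(iou_sketch((0., 0., 1., 1.), (0.5, 0., 1.5, 1.)))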
+ 371 - 0
paddlex/cv/nets/detection/loss/yolo_loss.py

@@ -0,0 +1,371 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence
+
+
+class YOLOv3Loss(object):
+    """
+    Combined loss for YOLOv3 network
+
+    Args:
+        batch_size (int): training batch size
+        ignore_thresh (float): threshold to ignore confidence loss
+        label_smooth (bool): whether to use label smoothing
+        use_fine_grained_loss (bool): whether to use the fine-grained YOLOv3
+                                      loss instead of fluid.layers.yolov3_loss
+    """
+
+    def __init__(self,
+                 batch_size=8,
+                 ignore_thresh=0.7,
+                 label_smooth=True,
+                 use_fine_grained_loss=False,
+                 iou_loss=None,
+                 iou_aware_loss=None,
+                 downsample=[32, 16, 8],
+                 scale_x_y=1.,
+                 match_score=False):
+        self._batch_size = batch_size
+        self._ignore_thresh = ignore_thresh
+        self._label_smooth = label_smooth
+        self._use_fine_grained_loss = use_fine_grained_loss
+        self._iou_loss = iou_loss
+        self._iou_aware_loss = iou_aware_loss
+        self.downsample = downsample
+        self.scale_x_y = scale_x_y
+        self.match_score = match_score
+
+    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
+                 anchor_masks, mask_anchors, num_classes, prefix_name):
+        if self._use_fine_grained_loss:
+            return self._get_fine_grained_loss(
+                outputs, targets, gt_box, self._batch_size, num_classes,
+                mask_anchors, self._ignore_thresh)
+        else:
+            losses = []
+            for i, output in enumerate(outputs):
+                scale_x_y = self.scale_x_y if not isinstance(
+                    self.scale_x_y, Sequence) else self.scale_x_y[i]
+                anchor_mask = anchor_masks[i]
+                loss = fluid.layers.yolov3_loss(
+                    x=output,
+                    gt_box=gt_box,
+                    gt_label=gt_label,
+                    gt_score=gt_score,
+                    anchors=anchors,
+                    anchor_mask=anchor_mask,
+                    class_num=num_classes,
+                    ignore_thresh=self._ignore_thresh,
+                    downsample_ratio=self.downsample[i],
+                    use_label_smooth=self._label_smooth,
+                    scale_x_y=scale_x_y,
+                    name=prefix_name + "yolo_loss" + str(i))
+
+                losses.append(fluid.layers.reduce_mean(loss))
+
+            return {'loss': sum(losses)}
+
+    def _get_fine_grained_loss(self,
+                               outputs,
+                               targets,
+                               gt_box,
+                               batch_size,
+                               num_classes,
+                               mask_anchors,
+                               ignore_thresh,
+                               eps=1.e-10):
+        """
+        Calculate the fine-grained YOLOv3 loss
+
+        Args:
+            outputs ([Variables]): list of Variables, output of backbone stages
+            targets ([Variables]): list of Variables, the targets for YOLO
+                                   loss calculation
+            gt_box (Variable): the ground-truth bounding boxes
+            batch_size (int): the training batch size
+            num_classes (int): number of classes in the dataset
+            mask_anchors ([[float]]): list of anchors in each output layer
+            ignore_thresh (float): if a predicted bbox overlaps any gt_box
+                                   with IoU greater than ignore_thresh, its
+                                   objectness loss is ignored.
+
+        Returns:
+            Type: dict
+                xy_loss (Variable): YOLOv3 (x, y) coordinates loss
+                wh_loss (Variable): YOLOv3 (w, h) coordinates loss
+                obj_loss (Variable): YOLOv3 objectness score loss
+                cls_loss (Variable): YOLOv3 classification loss
+
+        """
+
+        assert len(outputs) == len(targets), \
+            "YOLOv3 output layer number not equal target number"
+
+        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
+        if self._iou_loss is not None:
+            loss_ious = []
+        if self._iou_aware_loss is not None:
+            loss_iou_awares = []
+        for i, (output, target,
+                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
+            downsample = self.downsample[i]
+            an_num = len(anchors) // 2
+            if self._iou_aware_loss is not None:
+                ioup, output = self._split_ioup(output, an_num, num_classes)
+            x, y, w, h, obj, cls = self._split_output(output, an_num,
+                                                      num_classes)
+            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
+
+            tscale_tobj = tscale * tobj
+
+            scale_x_y = self.scale_x_y if not isinstance(
+                self.scale_x_y, Sequence) else self.scale_x_y[i]
+
+            if (abs(scale_x_y - 1.0) < eps):
+                loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    x, tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    y, ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+            else:
+                dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
+                                                                  1.0)
+                dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
+                                                                  1.0)
+                loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+
+            # NOTE: we refine the (w, h) loss to an L1 loss
+            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
+            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
+            loss_h = fluid.layers.abs(h - th) * tscale_tobj
+            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
+            if self._iou_loss is not None:
+                loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
+                                          downsample, self._batch_size,
+                                          scale_x_y)
+                loss_iou = loss_iou * tscale_tobj
+                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
+                loss_ious.append(fluid.layers.reduce_mean(loss_iou))
+
+            if self._iou_aware_loss is not None:
+                loss_iou_aware = self._iou_aware_loss(
+                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
+                    self._batch_size, scale_x_y)
+                loss_iou_aware = loss_iou_aware * tobj
+                loss_iou_aware = fluid.layers.reduce_sum(
+                    loss_iou_aware, dim=[1, 2, 3])
+                loss_iou_awares.append(
+                    fluid.layers.reduce_mean(loss_iou_aware))
+
+            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
+                output, obj, tobj, gt_box, self._batch_size, anchors,
+                num_classes, downsample, self._ignore_thresh, scale_x_y)
+
+            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
+                                                                      tcls)
+            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
+            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
+
+            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
+            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
+            loss_objs.append(
+                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
+            loss_clss.append(fluid.layers.reduce_mean(loss_cls))
+
+        losses_all = {
+            "loss_xy": fluid.layers.sum(loss_xys),
+            "loss_wh": fluid.layers.sum(loss_whs),
+            "loss_obj": fluid.layers.sum(loss_objs),
+            "loss_cls": fluid.layers.sum(loss_clss),
+        }
+        if self._iou_loss is not None:
+            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
+        if self._iou_aware_loss is not None:
+            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
+        return losses_all
+
+    def _split_ioup(self, output, an_num, num_classes):
+        """
+        Split output feature map to output, predicted iou
+        along channel dimension
+        """
+        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+        ioup = fluid.layers.sigmoid(ioup)
+        oriout = fluid.layers.slice(
+            output,
+            axes=[1],
+            starts=[an_num],
+            ends=[an_num * (num_classes + 6)])
+        return (ioup, oriout)
+
+    def _split_output(self, output, an_num, num_classes):
+        """
+        Split output feature map to x, y, w, h, objectness, classification
+        along channel dimension
+        """
+        x = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[0],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        y = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[1],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        w = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[2],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        h = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[3],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        obj = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[4],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        clss = []
+        stride = output.shape[1] // an_num
+        for m in range(an_num):
+            clss.append(
+                fluid.layers.slice(
+                    output,
+                    axes=[1],
+                    starts=[stride * m + 5],
+                    ends=[stride * m + 5 + num_classes]))
+        cls = fluid.layers.transpose(
+            fluid.layers.stack(
+                clss, axis=1), perm=[0, 1, 3, 4, 2])
+
+        return (x, y, w, h, obj, cls)
+
+    def _split_target(self, target):
+        """
+        split target to x, y, w, h, objectness, classification
+        along dimension 2
+
+        target is in shape [N, an_num, 6 + class_num, H, W]
+        """
+        tx = target[:, :, 0, :, :]
+        ty = target[:, :, 1, :, :]
+        tw = target[:, :, 2, :, :]
+        th = target[:, :, 3, :, :]
+
+        tscale = target[:, :, 4, :, :]
+        tobj = target[:, :, 5, :, :]
+
+        tcls = fluid.layers.transpose(
+            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
+        tcls.stop_gradient = True
+
+        return (tx, ty, tw, th, tscale, tobj, tcls)
+
+    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
+                       num_classes, downsample, ignore_thresh, scale_x_y):
+        # If a predicted bbox overlaps any gt bbox with IoU over
+        # ignore_thresh, its objectness loss is ignored. Process as follows:
+
+        # 1. get pred bbox, the same as YOLOv3 infer mode; use yolo_box here
+        # NOTE: img_size is set to 1.0 to get normalized pred bboxes
+        bbox, prob = fluid.layers.yolo_box(
+            x=output,
+            img_size=fluid.layers.ones(
+                shape=[batch_size, 2], dtype="int32"),
+            anchors=anchors,
+            class_num=num_classes,
+            conf_thresh=0.,
+            downsample_ratio=downsample,
+            clip_bbox=False,
+            scale_x_y=scale_x_y)
+
+        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
+        #    and gt bbox in each sample
+        if batch_size > 1:
+            preds = fluid.layers.split(bbox, batch_size, dim=0)
+            gts = fluid.layers.split(gt_box, batch_size, dim=0)
+        else:
+            preds = [bbox]
+            gts = [gt_box]
+            probs = [prob]
+        ious = []
+        for pred, gt in zip(preds, gts):
+
+            def box_xywh2xyxy(box):
+                x = box[:, 0]
+                y = box[:, 1]
+                w = box[:, 2]
+                h = box[:, 3]
+                return fluid.layers.stack(
+                    [
+                        x - w / 2.,
+                        y - h / 2.,
+                        x + w / 2.,
+                        y + h / 2.,
+                    ], axis=1)
+
+            pred = fluid.layers.squeeze(pred, axes=[0])
+            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
+            ious.append(fluid.layers.iou_similarity(pred, gt))
+
+        iou = fluid.layers.stack(ious, axis=0)
+        # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
+        #    Get obj_mask by tobj(holds gt_score), calculate objectness loss
+
+        max_iou = fluid.layers.reduce_max(iou, dim=-1)
+        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
+        if self.match_score:
+            max_prob = fluid.layers.reduce_max(prob, dim=-1)
+            iou_mask = iou_mask * fluid.layers.cast(
+                max_prob <= 0.25, dtype="float32")
+        output_shape = fluid.layers.shape(output)
+        an_num = len(anchors) // 2
+        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
+                                                   output_shape[3]))
+        iou_mask.stop_gradient = True
+
+        # NOTE: tobj holds gt_score, obj_mask holds object existence mask
+        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
+        obj_mask.stop_gradient = True
+
+        # For positive objectness grids, objectness loss should be calculated
+        # For negative objectness grids, objectness loss is calculated only
+        # where iou_mask == 1.0
+        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
+                                                                  obj_mask)
+        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
+        loss_obj_neg = fluid.layers.reduce_sum(
+            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
+
+        return loss_obj_pos, loss_obj_neg

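`_split_output` above depends on the head laying out each anchor's channels as [x, y, w, h, obj, cls_0, ..., cls_{C-1}], so a strided slice with stride 5 + num_classes collects the same field across all anchors. A NumPy illustration of that layout, with hypothetical shapes:

import numpy as np

an_num, num_classes, fh, fw = 3, 2, 4, 4    # hypothetical head shape
stride = 5 + num_classes                     # channels per anchor
output = np.random.rand(1, an_num * stride, fh, fw)

# the same fields _split_output extracts with fluid.layers.strided_slice
x   = output[:, 0::stride]                   # (1, an_num, fh, fw)
y   = output[:, 1::stride]
w   = output[:, 2::stride]
h   = output[:, 3::stride]
obj = output[:, 4::stride]
# per-anchor class logits, matching the slice loop over m and the transpose
cls = np.stack([output[:, m * stride + 5:m * stride + 5 + num_classes]
                for m in range(an_num)], axis=1).transpose(0, 1, 3, 4, 2)
print(x.shape, cls.shape)                    # (1, 3, 4, 4) (1, 3, 4, 4, 2)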
+ 270 - 0
paddlex/cv/nets/detection/ops.py

@@ -0,0 +1,270 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from numbers import Integral
+import math
+import six
+
+import paddle
+from paddle import fluid
+
+
+def DropBlock(input, block_size, keep_prob, is_test):
+    if is_test:
+        return input
+
+    def CalculateGamma(input, block_size, keep_prob):
+        # seeding probability per position so that roughly (1 - keep_prob)
+        # of the features fall inside dropped blocks
+        # NOTE: uses only the W dimension, i.e. assumes square feature maps
+        input_shape = fluid.layers.shape(input)
+        feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
+        feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
+        feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
+        feat_area = fluid.layers.pow(feat_shape_t, factor=2)
+
+        block_shape_t = fluid.layers.fill_constant(
+            shape=[1, 1, 1, 1], value=block_size, dtype='float32')
+        block_area = fluid.layers.pow(block_shape_t, factor=2)
+
+        useful_shape_t = feat_shape_t - block_shape_t + 1
+        useful_area = fluid.layers.pow(useful_shape_t, factor=2)
+
+        # gamma = (1 - keep_prob) / block_size^2 * feat_area / useful_area
+        upper_t = feat_area * (1 - keep_prob)
+        bottom_t = block_area * useful_area
+        output = upper_t / bottom_t
+        return output
+
+    gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
+    input_shape = fluid.layers.shape(input)
+    p = fluid.layers.expand_as(gamma, input)
+
+    input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
+    random_matrix = fluid.layers.uniform_random(
+        input_shape_tmp, dtype='float32', min=0.0, max=1.0)
+    # Bernoulli(gamma) seed mask
+    one_zero_m = fluid.layers.less_than(random_matrix, p)
+    one_zero_m.stop_gradient = True
+    one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")
+
+    # max-pooling expands each seed into a block_size x block_size block
+    mask_flag = fluid.layers.pool2d(
+        one_zero_m,
+        pool_size=block_size,
+        pool_type='max',
+        pool_stride=1,
+        pool_padding=block_size // 2)
+    mask = 1.0 - mask_flag
+
+    # rescale kept activations by (total elements / kept elements)
+    elem_numel = fluid.layers.reduce_prod(input_shape)
+    elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
+    elem_numel_m.stop_gradient = True
+
+    elem_sum = fluid.layers.reduce_sum(mask)
+    elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
+    elem_sum_m.stop_gradient = True
+
+    output = input * mask * elem_numel_m / elem_sum_m
+    return output
+
+
+class MultiClassNMS(object):
+    def __init__(self,
+                 score_threshold=.05,
+                 nms_top_k=-1,
+                 keep_top_k=100,
+                 nms_threshold=.5,
+                 normalized=False,
+                 nms_eta=1.0,
+                 background_label=0):
+        super(MultiClassNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+        self.nms_threshold = nms_threshold
+        self.normalized = normalized
+        self.nms_eta = nms_eta
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        return fluid.layers.multiclass_nms(
+            bboxes=bboxes,
+            scores=scores,
+            score_threshold=self.score_threshold,
+            nms_top_k=self.nms_top_k,
+            keep_top_k=self.keep_top_k,
+            normalized=self.normalized,
+            nms_threshold=self.nms_threshold,
+            nms_eta=self.nms_eta,
+            background_label=self.background_label)
+
+
+class MatrixNMS(object):
+    def __init__(self,
+                 score_threshold=.05,
+                 post_threshold=.05,
+                 nms_top_k=-1,
+                 keep_top_k=100,
+                 use_gaussian=False,
+                 gaussian_sigma=2.,
+                 normalized=False,
+                 background_label=0):
+        super(MatrixNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.post_threshold = post_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+        self.normalized = normalized
+        self.use_gaussian = use_gaussian
+        self.gaussian_sigma = gaussian_sigma
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        return paddle.fluid.layers.matrix_nms(
+            bboxes=bboxes,
+            scores=scores,
+            score_threshold=self.score_threshold,
+            post_threshold=self.post_threshold,
+            nms_top_k=self.nms_top_k,
+            keep_top_k=self.keep_top_k,
+            normalized=self.normalized,
+            use_gaussian=self.use_gaussian,
+            gaussian_sigma=self.gaussian_sigma,
+            background_label=self.background_label)
+
+
+class MultiClassSoftNMS(object):
+    def __init__(
+            self,
+            score_threshold=0.01,
+            keep_top_k=300,
+            softnms_sigma=0.5,
+            normalized=False,
+            background_label=0, ):
+        super(MultiClassSoftNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.keep_top_k = keep_top_k
+        self.softnms_sigma = softnms_sigma
+        self.normalized = normalized
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        def create_tmp_var(program, name, dtype, shape, lod_level):
+            return program.current_block().create_var(
+                name=name, dtype=dtype, shape=shape, lod_level=lod_level)
+
+        def _soft_nms_for_cls(dets, sigma, thres):
+            """soft_nms_for_cls"""
+            dets_final = []
+            while len(dets) > 0:
+                maxpos = np.argmax(dets[:, 0])
+                dets_final.append(dets[maxpos].copy())
+                ts, tx1, ty1, tx2, ty2 = dets[maxpos]
+                scores = dets[:, 0]
+                # force remove bbox at maxpos
+                scores[maxpos] = -1
+                x1 = dets[:, 1]
+                y1 = dets[:, 2]
+                x2 = dets[:, 3]
+                y2 = dets[:, 4]
+                eta = 0 if self.normalized else 1
+                areas = (x2 - x1 + eta) * (y2 - y1 + eta)
+                xx1 = np.maximum(tx1, x1)
+                yy1 = np.maximum(ty1, y1)
+                xx2 = np.minimum(tx2, x2)
+                yy2 = np.minimum(ty2, y2)
+                w = np.maximum(0.0, xx2 - xx1 + eta)
+                h = np.maximum(0.0, yy2 - yy1 + eta)
+                inter = w * h
+                ovr = inter / (areas + areas[maxpos] - inter)
+                weight = np.exp(-(ovr * ovr) / sigma)
+                scores = scores * weight
+                idx_keep = np.where(scores >= thres)
+                dets[:, 0] = scores
+                dets = dets[idx_keep]
+            dets_final = np.array(dets_final).reshape(-1, 5)
+            return dets_final
+
+        def _soft_nms(bboxes, scores):
+            class_nums = scores.shape[-1]
+
+            softnms_thres = self.score_threshold
+            softnms_sigma = self.softnms_sigma
+            keep_top_k = self.keep_top_k
+
+            cls_boxes = [[] for _ in range(class_nums)]
+            cls_ids = [[] for _ in range(class_nums)]
+
+            start_idx = 1 if self.background_label == 0 else 0
+            for j in range(start_idx, class_nums):
+                inds = np.where(scores[:, j] >= softnms_thres)[0]
+                scores_j = scores[inds, j]
+                rois_j = bboxes[inds, j, :] if len(
+                    bboxes.shape) > 2 else bboxes[inds, :]
+                dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
+                    np.float32, copy=False)
+                cls_rank = np.argsort(-dets_j[:, 0])
+                dets_j = dets_j[cls_rank]
+
+                cls_boxes[j] = _soft_nms_for_cls(
+                    dets_j, sigma=softnms_sigma, thres=softnms_thres)
+                cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
+                                                                           1)
+
+            cls_boxes = np.vstack(cls_boxes[start_idx:])
+            cls_ids = np.vstack(cls_ids[start_idx:])
+            pred_result = np.hstack([cls_ids, cls_boxes])
+
+            # Limit to max_per_image detections **over all classes**
+            image_scores = cls_boxes[:, 0]
+            if len(image_scores) > keep_top_k:
+                image_thresh = np.sort(image_scores)[-keep_top_k]
+                keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
+                pred_result = pred_result[keep, :]
+
+            return pred_result
+
+        def _batch_softnms(bboxes, scores):
+            batch_offsets = bboxes.lod()
+            bboxes = np.array(bboxes)
+            scores = np.array(scores)
+            out_offsets = [0]
+            pred_res = []
+            if len(batch_offsets) > 0:
+                batch_offset = batch_offsets[0]
+                for i in range(len(batch_offset) - 1):
+                    s, e = batch_offset[i], batch_offset[i + 1]
+                    pred = _soft_nms(bboxes[s:e], scores[s:e])
+                    out_offsets.append(pred.shape[0] + out_offsets[-1])
+                    pred_res.append(pred)
+            else:
+                assert len(bboxes.shape) == 3
+                assert len(scores.shape) == 3
+                for i in range(bboxes.shape[0]):
+                    pred = _soft_nms(bboxes[i], scores[i])
+                    out_offsets.append(pred.shape[0] + out_offsets[-1])
+                    pred_res.append(pred)
+
+            res = fluid.LoDTensor()
+            res.set_lod([out_offsets])
+            if len(pred_res) == 0:
+                pred_res = np.array([[1]], dtype=np.float32)
+            res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
+            return res
+
+        pred_result = create_tmp_var(
+            fluid.default_main_program(),
+            name='softnms_pred_result',
+            dtype='float32',
+            shape=[-1, 6],
+            lod_level=1)
+        fluid.layers.py_func(
+            func=_batch_softnms, x=[bboxes, scores], out=pred_result)
+        return pred_result

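`MultiClassSoftNMS._soft_nms_for_cls` keeps the highest-scoring box on each iteration and decays the remaining scores with a Gaussian of their IoU against it, instead of discarding them outright. The hypothetical helper below mirrors one such round in plain NumPy (rows are [score, x1, y1, x2, y2]; the `normalized` eta offset is omitted):

import numpy as np

def soft_nms_round(dets, sigma=0.5):
    """One round of Gaussian soft-NMS: keep the top box, decay the rest."""
    top = np.argmax(dets[:, 0])
    kept = dets[top].copy()
    x1, y1, x2, y2 = dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    tx1, ty1, tx2, ty2 = kept[1:]
    iw = np.maximum(0.0, np.minimum(tx2, x2) - np.maximum(tx1, x1))
    ih = np.maximum(0.0, np.minimum(ty2, y2) - np.maximum(ty1, y1))
    inter = iw * ih
    iou = inter / ((x2 - x1) * (y2 - y1) + (tx2 - tx1) * (ty2 - ty1) - inter)
    dets[:, 0] *= np.exp(-(iou ** 2) / sigma)   # Gaussian decay, as above
    dets[top, 0] = -1.0                          # drop the kept box from later rounds
    return kept, dets

# hypothetical overlapping boxes: the runner-up is softly down-weighted
dets = np.array([[0.9, 0., 0., 10., 10.], [0.8, 1., 0., 11., 10.]])
print(soft_nms_round(dets)[1][:, 0])             # ~[-1.    0.21]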
+ 295 - 105
paddlex/cv/nets/detection/yolo_v3.py

@@ -16,25 +16,50 @@ from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
 from collections import OrderedDict
+from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
+from .ops import DropBlock
+from .loss.yolo_loss import YOLOv3Loss
+from .loss.iou_loss import IouLoss
+from .loss.iou_aware_loss import IouAwareLoss
+from .iou_aware import get_iou_aware_score
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence
 
 
 class YOLOv3:
-    def __init__(self,
-                 backbone,
-                 num_classes,
-                 mode='train',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 label_smooth=False,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ],
-                 fixed_input_shape=None):
+    def __init__(
+            self,
+            backbone,
+            mode='train',
+            # YOLOv3Head
+            num_classes=80,
+            anchors=None,
+            anchor_masks=None,
+            coord_conv=False,
+            iou_aware=False,
+            iou_aware_factor=0.4,
+            scale_x_y=1.,
+            spp=False,
+            drop_block=False,
+            use_matrix_nms=False,
+            # YOLOv3Loss
+            batch_size=8,
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_fine_grained_loss=False,
+            use_iou_loss=False,
+            iou_loss_weight=2.5,
+            iou_aware_loss_weight=1.0,
+            max_height=608,
+            max_width=608,
+            # NMS
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            fixed_input_shape=None):
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -46,56 +71,114 @@ class YOLOv3:
         self.mode = mode
         self.num_classes = num_classes
         self.backbone = backbone
-        self.ignore_thresh = ignore_threshold
-        self.label_smooth = label_smooth
-        self.nms_score_threshold = nms_score_threshold
-        self.nms_topk = nms_topk
-        self.nms_keep_topk = nms_keep_topk
-        self.nms_iou_threshold = nms_iou_threshold
         self.norm_decay = 0.0
         self.prefix_name = ''
-        self.train_random_shapes = train_random_shapes
+        self.use_fine_grained_loss = use_fine_grained_loss
         self.fixed_input_shape = fixed_input_shape
+        self.coord_conv = coord_conv
+        self.iou_aware = iou_aware
+        self.iou_aware_factor = iou_aware_factor
+        self.scale_x_y = scale_x_y
+        self.use_spp = spp
+        self.drop_block = drop_block
 
-    def _head(self, feats):
+        if use_matrix_nms:
+            self.nms = MatrixNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold,
+                post_threshold=0.01)
+        else:
+            self.nms = MultiClassNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                nms_threshold=nms_iou_threshold,
+                nms_top_k=nms_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold)
+        self.iou_loss = None
+        self.iou_aware_loss = None
+        if use_iou_loss:
+            self.iou_loss = IouLoss(
+                loss_weight=iou_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        if iou_aware:
+            self.iou_aware_loss = IouAwareLoss(
+                loss_weight=iou_aware_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        self.yolo_loss = YOLOv3Loss(
+            batch_size=batch_size,
+            ignore_thresh=ignore_threshold,
+            scale_x_y=scale_x_y,
+            label_smooth=label_smooth,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            iou_loss=self.iou_loss,
+            iou_aware_loss=self.iou_aware_loss)
+        self.conv_block_num = 2
+        self.block_size = 3
+        self.keep_prob = 0.9
+        self.downsample = [32, 16, 8]
+        self.clip_bbox = True
+
+    def _head(self, input, is_train=True):
         outputs = []
+
+        # get last out_layer_num blocks in reverse order
         out_layer_num = len(self.anchor_masks)
-        blocks = feats[-1:-out_layer_num - 1:-1]
-        route = None
+        blocks = input[-1:-out_layer_num - 1:-1]
 
+        route = None
         for i, block in enumerate(blocks):
-            if i > 0:
+            if i > 0:  # concat the upsampled route from the previous block
                 block = fluid.layers.concat(input=[route, block], axis=1)
             route, tip = self._detection_block(
                 block,
-                channel=512 // (2**i),
-                name=self.prefix_name + 'yolo_block.{}'.format(i))
+                channel=64 * (2**out_layer_num) // (2**i),
+                is_first=i == 0,
+                is_test=(not is_train),
+                conv_block_num=self.conv_block_num,
+                name=self.prefix_name + "yolo_block.{}".format(i))
 
-            num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
-            block_out = fluid.layers.conv2d(
-                input=tip,
-                num_filters=num_filters,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                act=None,
-                param_attr=ParamAttr(name=self.prefix_name +
-                                     'yolo_output.{}.conv.weights'.format(i)),
-                bias_attr=ParamAttr(
-                    regularizer=L2Decay(0.0),
-                    name=self.prefix_name +
-                    'yolo_output.{}.conv.bias'.format(i)))
-            outputs.append(block_out)
+            # out channel number = mask_num * (5 + class_num), or
+            # mask_num * (6 + class_num) when iou_aware adds an IoU channel
+            if self.iou_aware:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 6)
+            else:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 5)
+            with fluid.name_scope('yolo_output'):
+                block_out = fluid.layers.conv2d(
+                    input=tip,
+                    num_filters=num_filters,
+                    filter_size=1,
+                    stride=1,
+                    padding=0,
+                    act=None,
+                    param_attr=ParamAttr(
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.weights".format(i)),
+                    bias_attr=ParamAttr(
+                        regularizer=L2Decay(0.),
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.bias".format(i)))
+                outputs.append(block_out)
 
             if i < len(blocks) - 1:
+                # do not perform upsample in the last detection_block
                 route = self._conv_bn(
                     input=route,
                     ch_out=256 // (2**i),
                     filter_size=1,
                     stride=1,
                     padding=0,
-                    name=self.prefix_name + 'yolo_transition.{}'.format(i))
+                    is_test=(not is_train),
+                    name=self.prefix_name + "yolo_transition.{}".format(i))
+                # upsample
                 route = self._upsample(route)
+
         return outputs
 
     def _parse_anchors(self, anchors):
@@ -116,6 +199,54 @@ class YOLOv3:
                 assert mask < anchor_num, "anchor mask index overflow"
                 self.mask_anchors[-1].extend(anchors[mask])
 
+    def _create_tensor_from_numpy(self, numpy_array):
+        paddle_array = fluid.layers.create_global_var(
+            shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
+        fluid.layers.assign(numpy_array, paddle_array)
+        return paddle_array
+
+    def _add_coord(self, input, is_test=True):
+        if not self.coord_conv:
+            return input
+
+        # NOTE: this branch is used when exporting the model for TensorRT
+        #       inference; only batch_size=1 is supported since the input
+        #       shape must be fixed, so we create tensors with fixed shapes
+        #       from numpy arrays
+        if is_test and input.shape[2] > 0 and input.shape[3] > 0:
+            batch_size = 1
+            grid_x = int(input.shape[3])
+            grid_y = int(input.shape[2])
+            idx_i = np.array(
+                [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
+                dtype='float32')
+            gi_np = np.repeat(idx_i, grid_y, axis=0)
+            gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
+            gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
+
+            x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
+            x_range.stop_gradient = True
+            y_range = self._create_tensor_from_numpy(
+                gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
+            y_range.stop_gradient = True
+
+        # NOTE: in training mode, H and W are variable because of random-shape
+        #       training, so add_coord is implemented with the shape as a Variable
+        else:
+            input_shape = fluid.layers.shape(input)
+            b = input_shape[0]
+            h = input_shape[2]
+            w = input_shape[3]
+
+            x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
+            x_range = x_range - 1.
+            x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
+            x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
+            x_range.stop_gradient = True
+            y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
+            y_range.stop_gradient = True
+
+        return fluid.layers.concat([input, x_range, y_range], axis=1)
+
     def _conv_bn(self,
                  input,
                  ch_out,
@@ -151,18 +282,52 @@ class YOLOv3:
             out = fluid.layers.leaky_relu(x=out, alpha=0.1)
         return out
 
+    def _spp_module(self, input, is_test=True, name=""):
+        output1 = input
+        output2 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=5,
+            pool_stride=1,
+            pool_padding=2,
+            ceil_mode=False,
+            pool_type='max')
+        output3 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=9,
+            pool_stride=1,
+            pool_padding=4,
+            ceil_mode=False,
+            pool_type='max')
+        output4 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=13,
+            pool_stride=1,
+            pool_padding=6,
+            ceil_mode=False,
+            pool_type='max')
+        output = fluid.layers.concat(
+            input=[output1, output2, output3, output4], axis=1)
+        return output
+
     def _upsample(self, input, scale=2, name=None):
         out = fluid.layers.resize_nearest(
             input=input, scale=float(scale), name=name)
         return out
 
-    def _detection_block(self, input, channel, name=None):
-        assert channel % 2 == 0, "channel({}) cannot be divided by 2 in detection block({})".format(
-            channel, name)
+    def _detection_block(self,
+                         input,
+                         channel,
+                         conv_block_num=2,
+                         is_first=False,
+                         is_test=True,
+                         name=None):
+        assert channel % 2 == 0, \
+            "channel {} cannot be divided by 2 in detection block {}" \
+            .format(channel, name)
 
-        is_test = False if self.mode == 'train' else True
         conv = input
-        for i in range(2):
+        for j in range(conv_block_num):
+            conv = self._add_coord(conv, is_test=is_test)
             conv = self._conv_bn(
                 conv,
                 channel,
@@ -170,7 +335,17 @@ class YOLOv3:
                 stride=1,
                 padding=0,
                 is_test=is_test,
-                name='{}.{}.0'.format(name, i))
+                name='{}.{}.0'.format(name, j))
+            if self.use_spp and is_first and j == 1:
+                conv = self._spp_module(conv, is_test=is_test, name="spp")
+                conv = self._conv_bn(
+                    conv,
+                    512,
+                    filter_size=1,
+                    stride=1,
+                    padding=0,
+                    is_test=is_test,
+                    name='{}.{}.spp.conv'.format(name, j))
             conv = self._conv_bn(
                 conv,
                 channel * 2,
@@ -178,7 +353,21 @@ class YOLOv3:
                 stride=1,
                 padding=1,
                 is_test=is_test,
-                name='{}.{}.1'.format(name, i))
+                name='{}.{}.1'.format(name, j))
+            if self.drop_block and j == 0 and not is_first:
+                conv = DropBlock(
+                    conv,
+                    block_size=self.block_size,
+                    keep_prob=self.keep_prob,
+                    is_test=is_test)
+
+        if self.drop_block and is_first:
+            conv = DropBlock(
+                conv,
+                block_size=self.block_size,
+                keep_prob=self.keep_prob,
+                is_test=is_test)
+        conv = self._add_coord(conv, is_test=is_test)
         route = self._conv_bn(
             conv,
             channel,
@@ -187,8 +376,9 @@ class YOLOv3:
             padding=0,
             is_test=is_test,
             name='{}.2'.format(name))
+        new_route = self._add_coord(route, is_test=is_test)
         tip = self._conv_bn(
-            route,
+            new_route,
             channel * 2,
             filter_size=3,
             stride=1,
@@ -197,54 +387,44 @@ class YOLOv3:
             name='{}.tip'.format(name))
         return route, tip
 
-    def _get_loss(self, inputs, gt_box, gt_label, gt_score):
-        losses = []
-        downsample = 32
-        for i, input in enumerate(inputs):
-            loss = fluid.layers.yolov3_loss(
-                x=input,
-                gt_box=gt_box,
-                gt_label=gt_label,
-                gt_score=gt_score,
-                anchors=self.anchors,
-                anchor_mask=self.anchor_masks[i],
-                class_num=self.num_classes,
-                ignore_thresh=self.ignore_thresh,
-                downsample_ratio=downsample,
-                use_label_smooth=self.label_smooth,
-                name=self.prefix_name + 'yolo_loss' + str(i))
-            losses.append(fluid.layers.reduce_mean(loss))
-            downsample //= 2
-        return sum(losses)
+    def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
+        loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
+                              self.anchors, self.anchor_masks,
+                              self.mask_anchors, self.num_classes,
+                              self.prefix_name)
+        total_loss = fluid.layers.sum(list(loss.values()))
+        return total_loss
 
     def _get_prediction(self, inputs, im_size):
         boxes = []
         scores = []
-        downsample = 32
         for i, input in enumerate(inputs):
+            if self.iou_aware:
+                input = get_iou_aware_score(input,
+                                            len(self.anchor_masks[i]),
+                                            self.num_classes,
+                                            self.iou_aware_factor)
+            scale_x_y = self.scale_x_y if not isinstance(
+                self.scale_x_y, Sequence) else self.scale_x_y[i]
+
             box, score = fluid.layers.yolo_box(
                 x=input,
                 img_size=im_size,
                 anchors=self.mask_anchors[i],
                 class_num=self.num_classes,
-                conf_thresh=self.nms_score_threshold,
-                downsample_ratio=downsample,
-                name=self.prefix_name + 'yolo_box' + str(i))
+                conf_thresh=self.nms.score_threshold,
+                downsample_ratio=self.downsample[i],
+                name=self.prefix_name + 'yolo_box' + str(i),
+                clip_bbox=self.clip_bbox,
+                scale_x_y=scale_x_y)
             boxes.append(box)
             scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
-            downsample //= 2
+
         yolo_boxes = fluid.layers.concat(boxes, axis=1)
         yolo_scores = fluid.layers.concat(scores, axis=2)
-        pred = fluid.layers.multiclass_nms(
-            bboxes=yolo_boxes,
-            scores=yolo_scores,
-            score_threshold=self.nms_score_threshold,
-            nms_top_k=self.nms_topk,
-            keep_top_k=self.nms_keep_topk,
-            nms_threshold=self.nms_iou_threshold,
-            normalized=False,
-            nms_eta=1.0,
-            background_label=-1)
+        if type(self.nms) is MultiClassSoftNMS:
+            yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
+        pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
         return pred
 
     def generate_inputs(self):
@@ -267,6 +447,25 @@ class YOLOv3:
                 dtype='float32', shape=[None, None], name='gt_score')
             inputs['im_size'] = fluid.data(
                 dtype='int32', shape=[None, 2], name='im_size')
+            if self.use_fine_grained_loss:
+                downsample = 32
+                for i, mask in enumerate(self.anchor_masks):
+                    if self.fixed_input_shape is not None:
+                        target_shape = [
+                            self.fixed_input_shape[1] // downsample,
+                            self.fixed_input_shape[0] // downsample
+                        ]
+                    else:
+                        target_shape = [None, None]
+                    inputs['target{}'.format(i)] = fluid.data(
+                        dtype='float32',
+                        lod_level=0,
+                        shape=[
+                            None, len(mask), 6 + self.num_classes,
+                            target_shape[0], target_shape[1]
+                        ],
+                        name='target{}'.format(i))
+                    downsample //= 2
         elif self.mode == 'eval':
             inputs['im_size'] = fluid.data(
                 dtype='int32', shape=[None, 2], name='im_size')
@@ -285,28 +484,12 @@ class YOLOv3:
 
     def build_net(self, inputs):
         image = inputs['image']
-        if self.mode == 'train':
-            if isinstance(self.train_random_shapes,
-                          (list, tuple)) and len(self.train_random_shapes) > 0:
-                import numpy as np
-                shapes = np.array(self.train_random_shapes)
-                shapes = np.stack([shapes, shapes], axis=1).astype('float32')
-                shapes_tensor = fluid.layers.assign(shapes)
-                index = fluid.layers.uniform_random(
-                    shape=[1], dtype='float32', min=0.0, max=1)
-                index = fluid.layers.cast(
-                    index * len(self.train_random_shapes), dtype='int32')
-                shape = fluid.layers.gather(shapes_tensor, index)
-                shape = fluid.layers.reshape(shape, [-1])
-                shape = fluid.layers.cast(shape, dtype='int32')
-                image = fluid.layers.resize_nearest(
-                    image, out_shape=shape, align_corners=False)
         feats = self.backbone(image)
         if isinstance(feats, OrderedDict):
             feat_names = list(feats.keys())
             feats = [feats[name] for name in feat_names]
 
-        head_outputs = self._head(feats)
+        head_outputs = self._head(feats, self.mode == 'train')
         if self.mode == 'train':
             gt_box = inputs['gt_box']
             gt_label = inputs['gt_label']
@@ -320,8 +503,15 @@ class YOLOv3:
             whwh = fluid.layers.cast(whwh, dtype='float32')
             whwh.stop_gradient = True
             normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
+
+            targets = []
+            if self.use_fine_grained_loss:
+                for i, mask in enumerate(self.anchor_masks):
+                    k = 'target{}'.format(i)
+                    if k in inputs:
+                        targets.append(inputs[k])
             return self._get_loss(head_outputs, normalized_box, gt_label,
-                                  gt_score)
+                                  gt_score, targets)
         else:
             im_size = inputs['im_size']
             return self._get_prediction(head_outputs, im_size)

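Putting the yolo_v3.py changes together: the new constructor arguments switch on the PPYOLO components introduced in this diff. A hedged sketch of the wiring, where `some_backbone` is a placeholder for any feature extractor that returns the list of feature maps `build_net` expects:

# Illustrative only: `some_backbone` is a stand-in for a real backbone
# (e.g. a ResNet variant) mapping an image tensor to feature maps.
model = YOLOv3(
    backbone=some_backbone,
    mode='train',
    num_classes=80,
    coord_conv=True,             # CoordConv channels via _add_coord
    iou_aware=True,              # extra IoU channel + IouAwareLoss
    scale_x_y=1.05,              # center-decoding scale in loss and yolo_box
    spp=True,                    # SPP module in the first detection block
    drop_block=True,             # DropBlock regularization
    use_matrix_nms=True,         # MatrixNMS instead of MultiClassNMS
    use_fine_grained_loss=True,  # required: IoU / IoU-aware losses only run
    use_iou_loss=True)           # inside _get_fine_grained_loss
inputs = model.generate_inputs()   # includes target0..targetN when
                                   # use_fine_grained_loss is True
loss = model.build_net(inputs)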
+ 4 - 1
paddlex/cv/transforms/__init__.py

@@ -91,7 +91,10 @@ def arrange_transforms(model_type, class_name, transforms, mode='train'):
     elif model_type == 'segmenter':
         arrange_transform = seg_transforms.ArrangeSegmenter
     elif model_type == 'detector':
-        arrange_name = 'Arrange{}'.format(class_name)
+        if class_name == "PPYOLO":
+            arrange_name = 'ArrangeYOLOv3'
+        else:
+            arrange_name = 'Arrange{}'.format(class_name)
         arrange_transform = getattr(det_transforms, arrange_name)
     else:
         raise Exception("Unrecognized model type: {}".format(model_type))

+ 1 - 1
paddlex/cv/transforms/cls_transforms.py

@@ -46,7 +46,7 @@ class Compose(ClsTransform):
             raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
         self.transforms = transforms
-
+        self.batch_transforms = None
         # Check the ops in transforms; PaddleX-defined and imgaug ops are currently supported
         for op in self.transforms:
             if not isinstance(op, ClsTransform):

+ 185 - 0
paddlex/cv/transforms/det_transforms.py

@@ -55,6 +55,7 @@ class Compose(DetTransform):
             raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
         self.transforms = transforms
+        self.batch_transforms = None
         self.use_mixup = False
         for t in self.transforms:
             if type(t).__name__ == 'MixupImage':
@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
                         mean=mean, std=std)
             ]
         super(ComposedYOLOv3Transforms, self).__init__(transforms)
+
+
+class BatchRandomShape(DetTransform):
+    """调整图像大小(resize)。
+
+    对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
+    注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
+
+    Args:
+        random_shapes (list): resize大小选择列表。
+            默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
+        interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
+            ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
+    Raises:
+        ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
+                    'AREA', 'LANCZOS4', 'RANDOM']中。
+    """
+
+    # mapping from interpolation mode name to the OpenCV flag
+    interp_dict = {
+        'NEAREST': cv2.INTER_NEAREST,
+        'LINEAR': cv2.INTER_LINEAR,
+        'CUBIC': cv2.INTER_CUBIC,
+        'AREA': cv2.INTER_AREA,
+        'LANCZOS4': cv2.INTER_LANCZOS4
+    }
+
+    def __init__(
+            self,
+            random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+            interp='RANDOM'):
+        if not (interp == "RANDOM" or interp in self.interp_dict):
+            raise ValueError("interp should be one of {}".format(
+                self.interp_dict.keys()))
+        self.random_shapes = random_shapes
+        self.interp = interp
+
+    def __call__(self, batch_data):
+        """
+        Args:
+            batch_data (list): batch of samples, each composed of an image and
+                its related information.
+        Returns:
+            list: batch of samples, each composed of an image and its related
+                information.
+        """
+        shape = np.random.choice(self.random_shapes)
+
+        if self.interp == "RANDOM":
+            interp = random.choice(list(self.interp_dict.keys()))
+        else:
+            interp = self.interp
+        for data_id, data in enumerate(batch_data):
+            data_list = list(data)
+            im = data_list[0]
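+            # CHW -> HWC: the cv2-based resize expects (h, w, c)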
+            im = np.swapaxes(im, 1, 0)
+            im = np.swapaxes(im, 1, 2)
+            im = resize(im, shape, self.interp_dict[interp])
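+            # HWC -> CHW: restore the channel-first network layout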
+            im = np.swapaxes(im, 1, 2)
+            im = np.swapaxes(im, 1, 0)
+            data_list[0] = im
+            batch_data[data_id] = tuple(data_list)
+        return batch_data
+
+
+class GenerateYoloTarget(object):
+    """生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
+       该transform只在YOLOv3计算细粒度loss时使用。
+
+       Args:
+           anchors (list|tuple): anchor框的宽度和高度。
+           anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
+           num_classes (int): 类别数。默认为80。
+           iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
+    """
+
+    def __init__(self,
+                 anchors,
+                 anchor_masks,
+                 downsample_ratios,
+                 num_classes=80,
+                 iou_thresh=1.):
+        super(GenerateYoloTarget, self).__init__()
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        self.downsample_ratios = downsample_ratios
+        self.num_classes = num_classes
+        self.iou_thresh = iou_thresh
+
+    def __call__(self, batch_data):
+        """
+        Args:
+            batch_data (list): batch of samples, each composed of an image and
+                its related information.
+        Returns:
+            list: batch of samples with the following fields appended to each
+                  sample:
+                  - target0 (np.ndarray): position encoding of the YOLOv3
+                    ground truth on feature layer 0, with shape (number of
+                    anchors on layer 0, 6 + number of classes, h of layer 0,
+                    w of layer 0).
+                  - target1 (np.ndarray): position encoding of the YOLOv3
+                    ground truth on feature layer 1, with shape (number of
+                    anchors on layer 1, 6 + number of classes, h of layer 1,
+                    w of layer 1).
+                  - ...
+                  - targetn (np.ndarray): position encoding of the YOLOv3
+                    ground truth on feature layer n, with shape (number of
+                    anchors on layer n, 6 + number of classes, h of layer n,
+                    w of layer n).
+                  n is determined by the length of anchor_masks.
+        """
+        im = batch_data[0][0]
+        h = im.shape[1]
+        w = im.shape[2]
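+        # anchor sizes normalized by the network input width and height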
+        an_hw = np.array(self.anchors) / np.array([[w, h]])
+        for data_id, data in enumerate(batch_data):
+            gt_bbox = data[1]
+            gt_class = data[2]
+            gt_score = data[3]
+            im_shape = data[4]
+            origin_h = float(im_shape[0])
+            origin_w = float(im_shape[1])
+            data_list = list(data)
+            for i, (
+                    mask, downsample_ratio
+            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
+                grid_h = int(h / downsample_ratio)
+                grid_w = int(w / downsample_ratio)
+                target = np.zeros(
+                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
+                    dtype=np.float32)
+                for b in range(gt_bbox.shape[0]):
+                    gx = gt_bbox[b, 0] / float(origin_w)
+                    gy = gt_bbox[b, 1] / float(origin_h)
+                    gw = gt_bbox[b, 2] / float(origin_w)
+                    gh = gt_bbox[b, 3] / float(origin_h)
+                    cls = gt_class[b]
+                    score = gt_score[b]
+                    if gw <= 0. or gh <= 0. or score <= 0.:
+                        continue
+                    # find best match anchor index
+                    best_iou = 0.
+                    best_idx = -1
+                    for an_idx in range(an_hw.shape[0]):
+                        iou = jaccard_overlap(
+                            [0., 0., gw, gh],
+                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
+                        if iou > best_iou:
+                            best_iou = iou
+                            best_idx = an_idx
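+                    # index of the grid cell that contains the gt box center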
+                    gi = int(gx * grid_w)
+                    gj = int(gy * grid_h)
+                    # the gt box should be regressed in this layer if the
+                    # best-match anchor index is in this layer's anchor mask
+                    if best_idx in mask:
+                        best_n = mask.index(best_idx)
+                        # x, y, w, h, scale
+                        target[best_n, 0, gj, gi] = gx * grid_w - gi
+                        target[best_n, 1, gj, gi] = gy * grid_h - gj
+                        target[best_n, 2, gj, gi] = np.log(
+                            gw * w / self.anchors[best_idx][0])
+                        target[best_n, 3, gj, gi] = np.log(
+                            gh * h / self.anchors[best_idx][1])
+                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
+                        # objectness record gt_score
+                        target[best_n, 5, gj, gi] = score
+                        # classification
+                        target[best_n, 6 + cls, gj, gi] = 1.
+                    # For non-matched anchors, calculate the target if the iou
+                    # between anchor and gt is larger than iou_thresh
+                    if self.iou_thresh < 1:
+                        for idx, mask_i in enumerate(mask):
+                            if mask_i == best_idx:
+                                continue
+                            iou = jaccard_overlap(
+                                [0., 0., gw, gh],
+                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
+                            if iou > self.iou_thresh:
+                                # x, y, w, h, scale
+                                target[idx, 0, gj, gi] = gx * grid_w - gi
+                                target[idx, 1, gj, gi] = gy * grid_h - gj
+                                target[idx, 2, gj, gi] = np.log(
+                                    gw * w / self.anchors[mask_i][0])
+                                target[idx, 3, gj, gi] = np.log(
+                                    gh * h / self.anchors[mask_i][1])
+                                target[idx, 4, gj, gi] = 2.0 - gw * gh
+                                # objectness record gt_score
+                                target[idx, 5, gj, gi] = score
+                                # classification
+                                target[idx, 6 + cls, gj, gi] = 1.
+                data_list.append(target)
+            batch_data[data_id] = tuple(data_list)
+        return batch_data
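A hedged usage sketch (the wiring below is an assumption for illustration, not taken from this diff): the two batch transforms are attached to the batch_transforms hook added to Compose above, and BatchRandomShape must run before GenerateYoloTarget because the targets are computed from the final image size:

    yolo_anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                    [59, 119], [116, 90], [156, 198], [373, 326]]
    train_transforms = Compose([RandomHorizontalFlip(), Normalize()])
    train_transforms.batch_transforms = [
        BatchRandomShape(),  # resize the whole batch to one random shape first
        GenerateYoloTarget(  # then build per-layer targets from the resized batch
            anchors=yolo_anchors,
            anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
            downsample_ratios=[32, 16, 8],
            num_classes=80),
    ]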

+ 1 - 0
paddlex/cv/transforms/seg_transforms.py

@@ -49,6 +49,7 @@ class Compose(SegTransform):
             raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
         self.transforms = transforms
+        self.batch_transforms = None
         self.to_rgb = False
         # Check the ops in transforms; currently PaddleX-defined transforms and imgaug ops are supported
         for op in self.transforms:

+ 1 - 0
paddlex/det.py

@@ -17,6 +17,7 @@ from . import cv
 
 FasterRCNN = cv.models.FasterRCNN
 YOLOv3 = cv.models.YOLOv3
+PPYOLO = cv.models.PPYOLO
 MaskRCNN = cv.models.MaskRCNN
 transforms = cv.transforms.det_transforms
 visualize = cv.models.utils.visualize.visualize_detection

+ 58 - 0
tutorials/train/object_detection/ppyolo.py

@@ -0,0 +1,58 @@
+# Environment variable configuration, used to control whether the GPU is used
+# Documentation: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+# Download and decompress the insect detection dataset
+insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(insect_dataset, path='./')
+
+# Define the transforms for training and evaluation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
+train_transforms = transforms.Compose([
+    transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
+    transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
+        target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize(
+        target_size=608, interp='CUBIC'), transforms.Normalize()
+])
+
+# Define the datasets used for training and evaluation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms)
+
+# Initialize the model and start training
+# Training metrics can be viewed with VisualDL; see https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
+model = pdx.det.PPYOLO(num_classes=num_classes)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
+# Parameter descriptions and tuning guide: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=270,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.000125,
+    lr_decay_epochs=[210, 240],
+    save_dir='output/ppyolo',
+    use_vdl=True)
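After training, a quick inference sketch (the model path follows save_dir above; the image path is hypothetical):

    # load the best checkpoint saved under save_dir and run prediction
    model = pdx.load_model('output/ppyolo/best_model')
    image = 'insect_det/JPEGImages/0001.jpg'  # hypothetical sample image
    result = model.predict(image)
    pdx.det.visualize(image, result, threshold=0.5, save_dir='./output/ppyolo')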