FlyingQianMM 5 سال پیش
والد
کامیت
ed8de1ae27

+ 10 - 8
paddlex/cv/datasets/dataset.py

@@ -115,7 +115,7 @@ def multithread_reader(mapper,
         while not isinstance(sample, EndSignal):
             batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
             sample = out_queue.get()
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
             else:
                 batch_data.append(sample)
                 if len(batch_data) == batch_size:
-                    batch_data = generate_minibatch(batch_data)
+                    batch_data = generate_minibatch(batch_data, mapper=mapper)
                     yield batch_data
                     batch_data = []
         if not drop_last and len(batch_data) != 0:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
 
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
             else:
                 batch_data.append(sample)
                 if len(batch_data) == batch_size:
-                    batch_data = generate_minibatch(batch_data)
+                    batch_data = generate_minibatch(batch_data, mapper=mapper)
                     yield batch_data
                     batch_data = []
         if len(batch_data) != 0 and not drop_last:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
 
     return queue_reader
 
 
-def generate_minibatch(batch_data, label_padding_value=255):
+def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
+    if mapper is not None and mapper.batch_transforms is not None:
+        for op in mapper.batch_transforms:
+            batch_data = op(batch_data)
     # if batch_size is 1, do not pad the image
     if len(batch_data) == 1:
         return batch_data
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 2:
-           # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
+            # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
             if len(data[1]) == 0 or 'padding' not in [
                     data[1][i][0] for i in range(len(data[1]))
             ]:
                 data[1].append(('padding', [im_h, im_w]))
             padding_batch.append((padding_im, data[1], data[2]))
 
-            
         elif len(data) > 1:
             if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation during the training

+ 19 - 8
paddlex/cv/models/base.py

@@ -94,6 +94,8 @@ class BaseAPI:
         self.train_inputs, self.train_outputs = self.build_net(mode='train')
         self.train_prog = fluid.default_main_program()
         startup_prog = fluid.default_startup_program()
+        self.train_prog.random_seed = 1000
+        startup_prog.random_seed = 1000
 
         # 构建预测网络
         self.test_prog = fluid.Program()
@@ -246,8 +248,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
-                                                      pretrain_weights, fuse_bn)
+            paddlex.utils.utils.load_pretrain_weights(
+                self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # 进行裁剪
         if sensitivities_file is not None:
             import paddleslim
@@ -351,7 +353,9 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))
 
     def export_inference_model(self, save_dir):
-        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_input_names = [
+            var.name for var in list(self.test_inputs.values())
+        ]
         test_outputs = list(self.test_outputs.values())
         with fluid.scope_guard(self.scope):
             if self.__class__.__name__ == 'MaskRCNN':
@@ -389,7 +393,8 @@ class BaseAPI:
 
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(
+            save_dir))
 
     def train_loop(self,
                    num_epochs,
@@ -516,11 +521,13 @@ class BaseAPI:
                         eta = ((num_epochs - i) * total_num_steps - step - 1
                                ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * time_eval_one_epoch
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * time_eval_one_epoch
                     else:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
 
                     logging.info(
@@ -543,6 +550,8 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.apply_program)
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
@@ -569,6 +578,8 @@ class BaseAPI:
                             log_writer.add_scalar(
                                 "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.restore_program)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
                 if best_model_epoch > 0:

+ 125 - 18
paddlex/cv/models/yolo_v3.py

@@ -19,6 +19,8 @@ import os.path as osp
 import numpy as np
 from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.optimizer import ExponentialMovingAverage
 import paddlex.utils.logging as logging
 import paddlex
 import copy
@@ -28,6 +30,10 @@ from .base import BaseAPI
 from collections import OrderedDict
 from .utils.detection_eval import eval_results, bbox2out
 
+import random
+random.seed(0)
+np.random.seed(0)
+
 
 class YOLOv3(BaseAPI):
     """构建YOLOv3,并实现其训练、评估、预测和模型导出。
@@ -50,24 +56,37 @@ class YOLOv3(BaseAPI):
         train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
     """
 
-    def __init__(self,
-                 num_classes=80,
-                 backbone='MobileNetV1',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 label_smooth=False,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+    def __init__(
+            self,
+            num_classes=80,
+            backbone='MobileNetV1',
+            with_dcn_v2=False,
+            # YOLO Head
+            anchors=None,
+            anchor_masks=None,
+            use_coord_conv=False,
+            use_iou_aware=False,
+            use_spp=False,
+            use_drop_block=False,
+            scale_x_y=1.0,
+            # YOLOv3 Loss
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_iou_loss=False,
+            # NMS
+            use_matrix_nms=False,
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            train_random_shapes=[
+                320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+            ]):
         self.init_params = locals()
         super(YOLOv3, self).__init__('detector')
         backbones = [
-            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
+            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large',
+            'ResNet50_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
@@ -75,6 +94,11 @@ class YOLOv3(BaseAPI):
         self.num_classes = num_classes
         self.anchors = anchors
         self.anchor_masks = anchor_masks
+        if anchors is None:
+            self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                            [59, 119], [116, 90], [156, 198], [373, 326]]
+        if anchor_masks is None:
+            self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
         self.ignore_threshold = ignore_threshold
         self.nms_score_threshold = nms_score_threshold
         self.nms_topk = nms_topk
@@ -84,6 +108,20 @@ class YOLOv3(BaseAPI):
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = None
+        self.use_fine_grained_loss = False
+        if use_coord_conv or use_iou_aware or use_spp or use_drop_block or use_iou_loss:
+            self.use_fine_grained_loss = True
+        self.use_coord_conv = use_coord_conv
+        self.use_iou_aware = use_iou_aware
+        self.use_spp = use_spp
+        self.use_drop_block = use_drop_block
+        self.use_iou_loss = use_iou_loss
+        self.scale_x_y = scale_x_y
+        self.max_height = 608
+        self.max_width = 608
+        self.use_matrix_nms = use_matrix_nms
+        self.use_ema = False
+        self.with_dcn_v2 = with_dcn_v2
 
     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -102,6 +140,16 @@ class YOLOv3(BaseAPI):
             model_name = backbone_name.split('_')[1]
             backbone = paddlex.cv.nets.MobileNetV3(
                 norm_type='sync_bn', model_name=model_name)
+        elif backbone_name == 'ResNet50_vd':
+            backbone = paddlex.cv.nets.ResNet(
+                norm_type='sync_bn',
+                layers=50,
+                freeze_norm=False,
+                norm_decay=0.,
+                feature_maps=[3, 4, 5],
+                freeze_at=0,
+                variant='d',
+                dcn_v2_stages=[5] if self.with_dcn_v2 else [])
         return backbone
 
     def build_net(self, mode='train'):
@@ -117,14 +165,31 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes,
-            fixed_input_shape=self.fixed_input_shape)
+            fixed_input_shape=self.fixed_input_shape,
+            coord_conv=self.use_coord_conv,
+            iou_aware=self.use_iou_aware,
+            scale_x_y=self.scale_x_y,
+            spp=self.use_spp,
+            drop_block=self.use_drop_block,
+            use_matrix_nms=self.use_matrix_nms,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            use_iou_loss=self.use_iou_loss,
+            batch_size=self.batch_size_per_gpu
+            if hasattr(self, 'batch_size_per_gpu') else 8)
+        if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
+            model.max_height = self.max_height
+            model.max_width = self.max_width
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
-        outputs = OrderedDict([('bbox', model_out)])
+        outputs = OrderedDict([('bbox', model_out[0])])
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs = OrderedDict([('loss', model_out)])
+            if self.use_ema:
+                global_steps = _decay_step_counter()
+                self.ema = ExponentialMovingAverage(
+                    self.ema_decay, thres_steps=global_steps)
+                self.ema.update()
         return inputs, outputs
 
     def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
@@ -172,6 +237,8 @@ class YOLOv3(BaseAPI):
               warmup_start_lr=0.0,
               lr_decay_epochs=[213, 240],
               lr_decay_gamma=0.1,
+              use_ema=False,
+              ema_decay=0.9998,
               metric=None,
               use_vdl=False,
               sensitivities_file=None,
@@ -242,6 +309,46 @@ class YOLOv3(BaseAPI):
                 lr_decay_gamma=lr_decay_gamma,
                 num_steps_each_epoch=num_steps_each_epoch)
         self.optimizer = optimizer
+        self.use_ema = use_ema
+        self.ema_decay = ema_decay
+
+        self.batch_size_per_gpu = int(train_batch_size /
+                                      paddlex.env_info['num'])
+        if self.use_fine_grained_loss:
+            for transform in train_dataset.transforms.transforms:
+                if isinstance(transform, paddlex.det.transforms.Resize):
+                    self.max_height = transform.target_size
+                    self.max_width = transform.target_size
+                    break
+        if train_dataset.transforms.batch_transforms is None:
+            train_dataset.transforms.batch_transforms = list()
+        define_random_shape = False
+        for bt in train_dataset.transforms.batch_transforms:
+            if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
+                define_random_shape = True
+        if not define_random_shape:
+            if isinstance(self.train_random_shapes,
+                          (list, tuple)) and len(self.train_random_shapes) > 0:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.BatchRandomShape(
+                        random_shapes=self.train_random_shapes))
+                if self.use_fine_grained_loss:
+                    self.max_height = max(self.max_height,
+                                          max(self.train_random_shapes))
+                    self.max_width = max(self.max_width,
+                                         max(self.train_random_shapes))
+        if self.use_fine_grained_loss:
+            define_generate_target = False
+            for bt in train_dataset.transforms.batch_transforms:
+                if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
+                    define_generate_target = True
+            if not define_generate_target:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.GenerateYoloTarget(
+                        anchors=self.anchors,
+                        anchor_masks=self.anchor_masks,
+                        num_classes=self.num_classes,
+                        downsample_ratios=[32, 16, 8]))
         # 构建训练、验证、预测网络
         self.build_program()
         # 初始化网络权重

+ 85 - 0
paddlex/cv/nets/detection/iou_aware.py

@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+
+
+def _split_ioup(output, an_num, num_classes):
+    """
+    Split new output feature map to output, predicted iou
+    along channel dimension
+    """
+    ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+    ioup = fluid.layers.sigmoid(ioup)
+
+    oriout = fluid.layers.slice(
+        output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
+
+    return (ioup, oriout)
+
+
+def _de_sigmoid(x, eps=1e-7):
+    x = fluid.layers.clip(x, eps, 1 / eps)
+    one = fluid.layers.fill_constant(
+        shape=[1, 1, 1, 1], dtype=x.dtype, value=1.)
+    x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps)
+    x = -fluid.layers.log(x)
+    return x
+
+
+def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
+    """
+    post process output objectness score
+    """
+    tensors = []
+    stride = output.shape[1] // an_num
+    for m in range(an_num):
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 0],
+                ends=[stride * m + 4]))
+        obj = fluid.layers.slice(
+            output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
+        obj = fluid.layers.sigmoid(obj)
+        ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
+
+        new_obj = fluid.layers.pow(obj, (
+            1 - iou_aware_factor)) * fluid.layers.pow(ip, iou_aware_factor)
+        new_obj = _de_sigmoid(new_obj)
+
+        tensors.append(new_obj)
+
+        tensors.append(
+            fluid.layers.slice(
+                output,
+                axes=[1],
+                starts=[stride * m + 5],
+                ends=[stride * m + 5 + num_classes]))
+
+    output = fluid.layers.concat(tensors, axis=1)
+
+    return output
+
+
+def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
+    ioup, output = _split_ioup(output, an_num, num_classes)
+    output = _postprocess_output(ioup, output, an_num, num_classes,
+                                 iou_aware_factor)
+    return output

+ 77 - 0
paddlex/cv/nets/detection/loss/iou_aware_loss.py

@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import NumpyArrayInitializer
+
+from paddle import fluid
+from .iou_loss import IouLoss
+
+
+class IouAwareLoss(IouLoss):
+    """
+    iou aware loss, see https://arxiv.org/abs/1912.05992
+    Args:
+        loss_weight (float): iou aware loss weight, default is 1.0
+        max_height (int): max height of input to support random shape input
+        max_width (int): max width of input to support random shape input
+    """
+
+    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
+        super(IouAwareLoss, self).__init__(
+            loss_weight=loss_weight,
+            max_height=max_height,
+            max_width=max_width)
+
+    def __call__(self,
+                 ioup,
+                 x,
+                 y,
+                 w,
+                 h,
+                 tx,
+                 ty,
+                 tw,
+                 th,
+                 anchors,
+                 downsample_ratio,
+                 batch_size,
+                 scale_x_y,
+                 eps=1.e-10):
+        '''
+        Args:
+            ioup ([Variables]): the predicted iou
+            x  | y | w | h  ([Variables]): the output of yolov3 for encoded x|y|w|h
+            tx |ty |tw |th  ([Variables]): the target of yolov3 for encoded x|y|w|h
+            anchors ([float]): list of anchors for current output layer
+            downsample_ratio (float): the downsample ratio for current output layer
+            batch_size (int): training batch size
+            eps (float): the decimal to prevent the denominator eqaul zero
+        '''
+
+        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
+                                    batch_size, False, scale_x_y, eps)
+        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
+                                  batch_size, True, scale_x_y, eps)
+        iouk = self._iou(pred, gt, ioup, eps)
+        iouk.stop_gradient = True
+
+        loss_iou_aware = fluid.layers.cross_entropy(
+            ioup, iouk, soft_label=True)
+        loss_iou_aware = loss_iou_aware * self._loss_weight
+        return loss_iou_aware

+ 235 - 0
paddlex/cv/nets/detection/loss/iou_loss.py

@@ -0,0 +1,235 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import NumpyArrayInitializer
+
+from paddle import fluid
+
+
+class IouLoss(object):
+    """
+    iou loss, see https://arxiv.org/abs/1908.03851
+    loss = 1.0 - iou * iou
+    Args:
+        loss_weight (float): iou loss weight, default is 2.5
+        max_height (int): max height of input to support random shape input
+        max_width (int): max width of input to support random shape input
+        ciou_term (bool): whether to add ciou_term
+        loss_square (bool): whether to square the iou term
+    """
+
+    def __init__(self,
+                 loss_weight=2.5,
+                 max_height=608,
+                 max_width=608,
+                 ciou_term=False,
+                 loss_square=True):
+        self._loss_weight = loss_weight
+        self._MAX_HI = max_height
+        self._MAX_WI = max_width
+        self.ciou_term = ciou_term
+        self.loss_square = loss_square
+
+    def __call__(self,
+                 x,
+                 y,
+                 w,
+                 h,
+                 tx,
+                 ty,
+                 tw,
+                 th,
+                 anchors,
+                 downsample_ratio,
+                 batch_size,
+                 scale_x_y=1.,
+                 ioup=None,
+                 eps=1.e-10):
+        '''
+        Args:
+            x  | y | w | h  ([Variables]): the output of yolov3 for encoded x|y|w|h
+            tx |ty |tw |th  ([Variables]): the target of yolov3 for encoded x|y|w|h
+            anchors ([float]): list of anchors for current output layer
+            downsample_ratio (float): the downsample ratio for current output layer
+            batch_size (int): training batch size
+            eps (float): the decimal to prevent the denominator eqaul zero
+        '''
+        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
+                                    batch_size, False, scale_x_y, eps)
+        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
+                                  batch_size, True, scale_x_y, eps)
+        iouk = self._iou(pred, gt, ioup, eps)
+        if self.loss_square:
+            loss_iou = 1. - iouk * iouk
+        else:
+            loss_iou = 1. - iouk
+        loss_iou = loss_iou * self._loss_weight
+
+        return loss_iou
+
+    def _iou(self, pred, gt, ioup=None, eps=1.e-10):
+        x1, y1, x2, y2 = pred
+        x1g, y1g, x2g, y2g = gt
+        x2 = fluid.layers.elementwise_max(x1, x2)
+        y2 = fluid.layers.elementwise_max(y1, y2)
+
+        xkis1 = fluid.layers.elementwise_max(x1, x1g)
+        ykis1 = fluid.layers.elementwise_max(y1, y1g)
+        xkis2 = fluid.layers.elementwise_min(x2, x2g)
+        ykis2 = fluid.layers.elementwise_min(y2, y2g)
+
+        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
+        intsctk = intsctk * fluid.layers.greater_than(
+            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
+        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
+                                                        ) - intsctk + eps
+        iouk = intsctk / unionk
+        if self.ciou_term:
+            ciou = self.get_ciou_term(pred, gt, iouk, eps)
+            iouk = iouk - ciou
+        return iouk
+
+    def get_ciou_term(self, pred, gt, iouk, eps):
+        x1, y1, x2, y2 = pred
+        x1g, y1g, x2g, y2g = gt
+
+        cx = (x1 + x2) / 2
+        cy = (y1 + y2) / 2
+        w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
+        h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')
+
+        cxg = (x1g + x2g) / 2
+        cyg = (y1g + y2g) / 2
+        wg = x2g - x1g
+        hg = y2g - y1g
+
+        # A or B
+        xc1 = fluid.layers.elementwise_min(x1, x1g)
+        yc1 = fluid.layers.elementwise_min(y1, y1g)
+        xc2 = fluid.layers.elementwise_max(x2, x2g)
+        yc2 = fluid.layers.elementwise_max(y2, y2g)
+
+        # DIOU term
+        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
+        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
+        diou_term = (dist_intersection + eps) / (dist_union + eps)
+        # CIOU term
+        ciou_term = 0
+        ar_gt = wg / hg
+        ar_pred = w / h
+        arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
+        ar_loss = 4. / np.pi / np.pi * arctan * arctan
+        alpha = ar_loss / (1 - iouk + ar_loss + eps)
+        alpha.stop_gradient = True
+        ciou_term = alpha * ar_loss
+        return diou_term + ciou_term
+
+    def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
+                        batch_size, is_gt, scale_x_y, eps):
+        grid_x = int(self._MAX_WI / downsample_ratio)
+        grid_y = int(self._MAX_HI / downsample_ratio)
+        an_num = len(anchors) // 2
+
+        shape_fmp = fluid.layers.shape(dcx)
+        shape_fmp.stop_gradient = True
+        # generate the grid_w x grid_h center of feature map
+        idx_i = np.array([[i for i in range(grid_x)]])
+        idx_j = np.array([[j for j in range(grid_y)]]).transpose()
+        gi_np = np.repeat(idx_i, grid_y, axis=0)
+        gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
+        gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
+        gj_np = np.repeat(idx_j, grid_x, axis=1)
+        gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
+        gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
+        gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
+        gi = fluid.layers.crop(x=gi_max, shape=dcx)
+        gi.stop_gradient = True
+        gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
+        gj = fluid.layers.crop(x=gj_max, shape=dcx)
+        gj.stop_gradient = True
+
+        grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
+        grid_x_act.stop_gradient = True
+        grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
+        grid_y_act.stop_gradient = True
+        if is_gt:
+            cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
+            cx.gradient = True
+            cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
+            cy.gradient = True
+        else:
+            dcx_sig = fluid.layers.sigmoid(dcx)
+            dcy_sig = fluid.layers.sigmoid(dcy)
+            if (abs(scale_x_y - 1.0) > eps):
+                dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
+                dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
+            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
+            cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act
+
+        anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
+        anchor_w_np = np.array(anchor_w_)
+        anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
+        anchor_w_np = np.tile(
+            anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
+        anchor_w_max = self._create_tensor_from_numpy(
+            anchor_w_np.astype(np.float32))
+        anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
+        anchor_w.stop_gradient = True
+        anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1]
+        anchor_h_np = np.array(anchor_h_)
+        anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
+        anchor_h_np = np.tile(
+            anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
+        anchor_h_max = self._create_tensor_from_numpy(
+            anchor_h_np.astype(np.float32))
+        anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
+        anchor_h.stop_gradient = True
+        # e^tw e^th
+        exp_dw = fluid.layers.exp(dw)
+        exp_dh = fluid.layers.exp(dh)
+        pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
+            (grid_x_act * downsample_ratio)
+        ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
+            (grid_y_act * downsample_ratio)
+        if is_gt:
+            exp_dw.stop_gradient = True
+            exp_dh.stop_gradient = True
+            pw.stop_gradient = True
+            ph.stop_gradient = True
+
+        x1 = cx - 0.5 * pw
+        y1 = cy - 0.5 * ph
+        x2 = cx + 0.5 * pw
+        y2 = cy + 0.5 * ph
+        if is_gt:
+            x1.stop_gradient = True
+            y1.stop_gradient = True
+            x2.stop_gradient = True
+            y2.stop_gradient = True
+
+        return x1, y1, x2, y2
+
+    def _create_tensor_from_numpy(self, numpy_array):
+        paddle_array = fluid.layers.create_parameter(
+            attr=ParamAttr(),
+            shape=numpy_array.shape,
+            dtype=numpy_array.dtype,
+            default_initializer=NumpyArrayInitializer(numpy_array))
+        paddle_array.stop_gradient = True
+        return paddle_array

+ 371 - 0
paddlex/cv/nets/detection/loss/yolo_loss.py

@@ -0,0 +1,371 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence
+
+
+class YOLOv3Loss(object):
+    """
+    Combined loss for YOLOv3 network
+
+    Args:
+        batch_size (int): training batch size
+        ignore_thresh (float): threshold to ignore confidence loss
+        label_smooth (bool): whether to use label smoothing
+        use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss
+                                      instead of fluid.layers.yolov3_loss
+    """
+
+    def __init__(self,
+                 batch_size=8,
+                 ignore_thresh=0.7,
+                 label_smooth=True,
+                 use_fine_grained_loss=False,
+                 iou_loss=None,
+                 iou_aware_loss=None,
+                 downsample=[32, 16, 8],
+                 scale_x_y=1.,
+                 match_score=False):
+        self._batch_size = batch_size
+        self._ignore_thresh = ignore_thresh
+        self._label_smooth = label_smooth
+        self._use_fine_grained_loss = use_fine_grained_loss
+        self._iou_loss = iou_loss
+        self._iou_aware_loss = iou_aware_loss
+        self.downsample = downsample
+        self.scale_x_y = scale_x_y
+        self.match_score = match_score
+
+    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
+                 anchor_masks, mask_anchors, num_classes, prefix_name):
+        if self._use_fine_grained_loss:
+            return self._get_fine_grained_loss(
+                outputs, targets, gt_box, self._batch_size, num_classes,
+                mask_anchors, self._ignore_thresh)
+        else:
+            losses = []
+            for i, output in enumerate(outputs):
+                scale_x_y = self.scale_x_y if not isinstance(
+                    self.scale_x_y, Sequence) else self.scale_x_y[i]
+                anchor_mask = anchor_masks[i]
+                loss = fluid.layers.yolov3_loss(
+                    x=output,
+                    gt_box=gt_box,
+                    gt_label=gt_label,
+                    gt_score=gt_score,
+                    anchors=anchors,
+                    anchor_mask=anchor_mask,
+                    class_num=num_classes,
+                    ignore_thresh=self._ignore_thresh,
+                    downsample_ratio=self.downsample[i],
+                    use_label_smooth=self._label_smooth,
+                    scale_x_y=scale_x_y,
+                    name=prefix_name + "yolo_loss" + str(i))
+
+                losses.append(fluid.layers.reduce_mean(loss))
+
+            return {'loss': sum(losses)}
+
+    def _get_fine_grained_loss(self,
+                               outputs,
+                               targets,
+                               gt_box,
+                               batch_size,
+                               num_classes,
+                               mask_anchors,
+                               ignore_thresh,
+                               eps=1.e-10):
+        """
+        Calculate fine grained YOLOv3 loss
+
+        Args:
+            outputs ([Variables]): List of Variables, output of backbone stages
+            targets ([Variables]): List of Variables, The targets for yolo
+                                   loss calculatation.
+            gt_box (Variable): The ground-truth boudding boxes.
+            batch_size (int): The training batch size
+            num_classes (int): class num of dataset
+            mask_anchors ([[float]]): list of anchors in each output layer
+            ignore_thresh (float): prediction bbox overlap any gt_box greater
+                                   than ignore_thresh, objectness loss will
+                                   be ignored.
+
+        Returns:
+            Type: dict
+                xy_loss (Variable): YOLOv3 (x, y) coordinates loss
+                wh_loss (Variable): YOLOv3 (w, h) coordinates loss
+                obj_loss (Variable): YOLOv3 objectness score loss
+                cls_loss (Variable): YOLOv3 classification loss
+
+        """
+
+        assert len(outputs) == len(targets), \
+            "YOLOv3 output layer number not equal target number"
+
+        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
+        if self._iou_loss is not None:
+            loss_ious = []
+        if self._iou_aware_loss is not None:
+            loss_iou_awares = []
+        for i, (output, target,
+                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
+            downsample = self.downsample[i]
+            an_num = len(anchors) // 2
+            if self._iou_aware_loss is not None:
+                ioup, output = self._split_ioup(output, an_num, num_classes)
+            x, y, w, h, obj, cls = self._split_output(output, an_num,
+                                                      num_classes)
+            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)
+
+            tscale_tobj = tscale * tobj
+
+            scale_x_y = self.scale_x_y if not isinstance(
+                self.scale_x_y, Sequence) else self.scale_x_y[i]
+
+            if (abs(scale_x_y - 1.0) < eps):
+                loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    x, tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    y, ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+            else:
+                dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
+                                                                  1.0)
+                dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
+                                                                  1.0)
+                loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
+                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
+                loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
+                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
+
+            # NOTE: we refined loss function of (w, h) as L1Loss
+            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
+            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
+            loss_h = fluid.layers.abs(h - th) * tscale_tobj
+            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
+            if self._iou_loss is not None:
+                loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
+                                          downsample, self._batch_size,
+                                          scale_x_y)
+                loss_iou = loss_iou * tscale_tobj
+                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
+                loss_ious.append(fluid.layers.reduce_mean(loss_iou))
+
+            if self._iou_aware_loss is not None:
+                loss_iou_aware = self._iou_aware_loss(
+                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
+                    self._batch_size, scale_x_y)
+                loss_iou_aware = loss_iou_aware * tobj
+                loss_iou_aware = fluid.layers.reduce_sum(
+                    loss_iou_aware, dim=[1, 2, 3])
+                loss_iou_awares.append(
+                    fluid.layers.reduce_mean(loss_iou_aware))
+
+            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
+                output, obj, tobj, gt_box, self._batch_size, anchors,
+                num_classes, downsample, self._ignore_thresh, scale_x_y)
+
+            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
+                                                                      tcls)
+            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
+            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
+
+            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
+            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
+            loss_objs.append(
+                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
+            loss_clss.append(fluid.layers.reduce_mean(loss_cls))
+
+        losses_all = {
+            "loss_xy": fluid.layers.sum(loss_xys),
+            "loss_wh": fluid.layers.sum(loss_whs),
+            "loss_obj": fluid.layers.sum(loss_objs),
+            "loss_cls": fluid.layers.sum(loss_clss),
+        }
+        if self._iou_loss is not None:
+            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
+        if self._iou_aware_loss is not None:
+            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
+        return losses_all
+
+    def _split_ioup(self, output, an_num, num_classes):
+        """
+        Split output feature map to output, predicted iou
+        along channel dimension
+        """
+        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
+        ioup = fluid.layers.sigmoid(ioup)
+        oriout = fluid.layers.slice(
+            output,
+            axes=[1],
+            starts=[an_num],
+            ends=[an_num * (num_classes + 6)])
+        return (ioup, oriout)
+
+    def _split_output(self, output, an_num, num_classes):
+        """
+        Split output feature map to x, y, w, h, objectness, classification
+        along channel dimension
+        """
+        x = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[0],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        y = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[1],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        w = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[2],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        h = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[3],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        obj = fluid.layers.strided_slice(
+            output,
+            axes=[1],
+            starts=[4],
+            ends=[output.shape[1]],
+            strides=[5 + num_classes])
+        clss = []
+        stride = output.shape[1] // an_num
+        for m in range(an_num):
+            clss.append(
+                fluid.layers.slice(
+                    output,
+                    axes=[1],
+                    starts=[stride * m + 5],
+                    ends=[stride * m + 5 + num_classes]))
+        cls = fluid.layers.transpose(
+            fluid.layers.stack(
+                clss, axis=1), perm=[0, 1, 3, 4, 2])
+
+        return (x, y, w, h, obj, cls)
+
+    def _split_target(self, target):
+        """
+        split target to x, y, w, h, objectness, classification
+        along dimension 2
+
+        target is in shape [N, an_num, 6 + class_num, H, W]
+        """
+        tx = target[:, :, 0, :, :]
+        ty = target[:, :, 1, :, :]
+        tw = target[:, :, 2, :, :]
+        th = target[:, :, 3, :, :]
+
+        tscale = target[:, :, 4, :, :]
+        tobj = target[:, :, 5, :, :]
+
+        tcls = fluid.layers.transpose(
+            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
+        tcls.stop_gradient = True
+
+        return (tx, ty, tw, th, tscale, tobj, tcls)
+
+    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
+                       num_classes, downsample, ignore_thresh, scale_x_y):
+        # A prediction bbox overlap any gt_bbox over ignore_thresh,
+        # objectness loss will be ignored, process as follows:
+
+        # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
+        # NOTE: img_size is set as 1.0 to get noramlized pred bbox
+        bbox, prob = fluid.layers.yolo_box(
+            x=output,
+            img_size=fluid.layers.ones(
+                shape=[batch_size, 2], dtype="int32"),
+            anchors=anchors,
+            class_num=num_classes,
+            conf_thresh=0.,
+            downsample_ratio=downsample,
+            clip_bbox=False,
+            scale_x_y=scale_x_y)
+
+        # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
+        #    and gt bbox in each sample
+        if batch_size > 1:
+            preds = fluid.layers.split(bbox, batch_size, dim=0)
+            gts = fluid.layers.split(gt_box, batch_size, dim=0)
+        else:
+            preds = [bbox]
+            gts = [gt_box]
+            probs = [prob]
+        ious = []
+        for pred, gt in zip(preds, gts):
+
+            def box_xywh2xyxy(box):
+                x = box[:, 0]
+                y = box[:, 1]
+                w = box[:, 2]
+                h = box[:, 3]
+                return fluid.layers.stack(
+                    [
+                        x - w / 2.,
+                        y - h / 2.,
+                        x + w / 2.,
+                        y + h / 2.,
+                    ], axis=1)
+
+            pred = fluid.layers.squeeze(pred, axes=[0])
+            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
+            ious.append(fluid.layers.iou_similarity(pred, gt))
+
+        iou = fluid.layers.stack(ious, axis=0)
+        # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
+        #    Get obj_mask by tobj(holds gt_score), calculate objectness loss
+
+        max_iou = fluid.layers.reduce_max(iou, dim=-1)
+        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
+        if self.match_score:
+            max_prob = fluid.layers.reduce_max(prob, dim=-1)
+            iou_mask = iou_mask * fluid.layers.cast(
+                max_prob <= 0.25, dtype="float32")
+        output_shape = fluid.layers.shape(output)
+        an_num = len(anchors) // 2
+        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
+                                                   output_shape[3]))
+        iou_mask.stop_gradient = True
+
+        # NOTE: tobj holds gt_score, obj_mask holds object existence mask
+        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
+        obj_mask.stop_gradient = True
+
+        # For positive objectness grids, objectness loss should be calculated
+        # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
+        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
+                                                                  obj_mask)
+        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
+        loss_obj_neg = fluid.layers.reduce_sum(
+            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
+
+        return loss_obj_pos, loss_obj_neg

+ 270 - 0
paddlex/cv/nets/detection/ops.py

@@ -0,0 +1,270 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from numbers import Integral
+import math
+import six
+
+import paddle
+from paddle import fluid
+
+
+def DropBlock(input, block_size, keep_prob, is_test):
+    if is_test:
+        return input
+
+    def CalculateGamma(input, block_size, keep_prob):
+        input_shape = fluid.layers.shape(input)
+        feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
+        feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
+        feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
+        feat_area = fluid.layers.pow(feat_shape_t, factor=2)
+
+        block_shape_t = fluid.layers.fill_constant(
+            shape=[1, 1, 1, 1], value=block_size, dtype='float32')
+        block_area = fluid.layers.pow(block_shape_t, factor=2)
+
+        useful_shape_t = feat_shape_t - block_shape_t + 1
+        useful_area = fluid.layers.pow(useful_shape_t, factor=2)
+
+        upper_t = feat_area * (1 - keep_prob)
+        bottom_t = block_area * useful_area
+        output = upper_t / bottom_t
+        return output
+
+    gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
+    input_shape = fluid.layers.shape(input)
+    p = fluid.layers.expand_as(gamma, input)
+
+    input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
+    random_matrix = fluid.layers.uniform_random(
+        input_shape_tmp, dtype='float32', min=0.0, max=1.0, seed=1000)
+    one_zero_m = fluid.layers.less_than(random_matrix, p)
+    one_zero_m.stop_gradient = True
+    one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")
+
+    mask_flag = fluid.layers.pool2d(
+        one_zero_m,
+        pool_size=block_size,
+        pool_type='max',
+        pool_stride=1,
+        pool_padding=block_size // 2)
+    mask = 1.0 - mask_flag
+
+    elem_numel = fluid.layers.reduce_prod(input_shape)
+    elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
+    elem_numel_m.stop_gradient = True
+
+    elem_sum = fluid.layers.reduce_sum(mask)
+    elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
+    elem_sum_m.stop_gradient = True
+
+    output = input * mask * elem_numel_m / elem_sum_m
+    return output
+
+
+class MultiClassNMS(object):
+    def __init__(self,
+                 score_threshold=.05,
+                 nms_top_k=-1,
+                 keep_top_k=100,
+                 nms_threshold=.5,
+                 normalized=False,
+                 nms_eta=1.0,
+                 background_label=0):
+        super(MultiClassNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+        self.nms_threshold = nms_threshold
+        self.normalized = normalized
+        self.nms_eta = nms_eta
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        return fluid.layers.multiclass_nms(
+            bboxes=bboxes,
+            scores=scores,
+            score_threshold=self.score_threshold,
+            nms_top_k=self.nms_top_k,
+            keep_top_k=self.keep_top_k,
+            normalized=self.normalized,
+            nms_threshold=self.nms_threshold,
+            nms_eta=self.nms_eta,
+            background_label=self.background_label)
+
+
+class MatrixNMS(object):
+    def __init__(self,
+                 score_threshold=.05,
+                 post_threshold=.05,
+                 nms_top_k=-1,
+                 keep_top_k=100,
+                 use_gaussian=False,
+                 gaussian_sigma=2.,
+                 normalized=False,
+                 background_label=0):
+        super(MatrixNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.post_threshold = post_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+        self.normalized = normalized
+        self.use_gaussian = use_gaussian
+        self.gaussian_sigma = gaussian_sigma
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        return paddle.fluid.layers.matrix_nms(
+            bboxes=bboxes,
+            scores=scores,
+            score_threshold=self.score_threshold,
+            post_threshold=self.post_threshold,
+            nms_top_k=self.nms_top_k,
+            keep_top_k=self.keep_top_k,
+            normalized=self.normalized,
+            use_gaussian=self.use_gaussian,
+            gaussian_sigma=self.gaussian_sigma,
+            background_label=self.background_label)
+
+
+class MultiClassSoftNMS(object):
+    def __init__(
+            self,
+            score_threshold=0.01,
+            keep_top_k=300,
+            softnms_sigma=0.5,
+            normalized=False,
+            background_label=0, ):
+        super(MultiClassSoftNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.keep_top_k = keep_top_k
+        self.softnms_sigma = softnms_sigma
+        self.normalized = normalized
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        def create_tmp_var(program, name, dtype, shape, lod_level):
+            return program.current_block().create_var(
+                name=name, dtype=dtype, shape=shape, lod_level=lod_level)
+
+        def _soft_nms_for_cls(dets, sigma, thres):
+            """soft_nms_for_cls"""
+            dets_final = []
+            while len(dets) > 0:
+                maxpos = np.argmax(dets[:, 0])
+                dets_final.append(dets[maxpos].copy())
+                ts, tx1, ty1, tx2, ty2 = dets[maxpos]
+                scores = dets[:, 0]
+                # force remove bbox at maxpos
+                scores[maxpos] = -1
+                x1 = dets[:, 1]
+                y1 = dets[:, 2]
+                x2 = dets[:, 3]
+                y2 = dets[:, 4]
+                eta = 0 if self.normalized else 1
+                areas = (x2 - x1 + eta) * (y2 - y1 + eta)
+                xx1 = np.maximum(tx1, x1)
+                yy1 = np.maximum(ty1, y1)
+                xx2 = np.minimum(tx2, x2)
+                yy2 = np.minimum(ty2, y2)
+                w = np.maximum(0.0, xx2 - xx1 + eta)
+                h = np.maximum(0.0, yy2 - yy1 + eta)
+                inter = w * h
+                ovr = inter / (areas + areas[maxpos] - inter)
+                weight = np.exp(-(ovr * ovr) / sigma)
+                scores = scores * weight
+                idx_keep = np.where(scores >= thres)
+                dets[:, 0] = scores
+                dets = dets[idx_keep]
+            dets_final = np.array(dets_final).reshape(-1, 5)
+            return dets_final
+
+        def _soft_nms(bboxes, scores):
+            class_nums = scores.shape[-1]
+
+            softnms_thres = self.score_threshold
+            softnms_sigma = self.softnms_sigma
+            keep_top_k = self.keep_top_k
+
+            cls_boxes = [[] for _ in range(class_nums)]
+            cls_ids = [[] for _ in range(class_nums)]
+
+            start_idx = 1 if self.background_label == 0 else 0
+            for j in range(start_idx, class_nums):
+                inds = np.where(scores[:, j] >= softnms_thres)[0]
+                scores_j = scores[inds, j]
+                rois_j = bboxes[inds, j, :] if len(
+                    bboxes.shape) > 2 else bboxes[inds, :]
+                dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
+                    np.float32, copy=False)
+                cls_rank = np.argsort(-dets_j[:, 0])
+                dets_j = dets_j[cls_rank]
+
+                cls_boxes[j] = _soft_nms_for_cls(
+                    dets_j, sigma=softnms_sigma, thres=softnms_thres)
+                cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
+                                                                           1)
+
+            cls_boxes = np.vstack(cls_boxes[start_idx:])
+            cls_ids = np.vstack(cls_ids[start_idx:])
+            pred_result = np.hstack([cls_ids, cls_boxes])
+
+            # Limit to max_per_image detections **over all classes**
+            image_scores = cls_boxes[:, 0]
+            if len(image_scores) > keep_top_k:
+                image_thresh = np.sort(image_scores)[-keep_top_k]
+                keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
+                pred_result = pred_result[keep, :]
+
+            return pred_result
+
+        def _batch_softnms(bboxes, scores):
+            batch_offsets = bboxes.lod()
+            bboxes = np.array(bboxes)
+            scores = np.array(scores)
+            out_offsets = [0]
+            pred_res = []
+            if len(batch_offsets) > 0:
+                batch_offset = batch_offsets[0]
+                for i in range(len(batch_offset) - 1):
+                    s, e = batch_offset[i], batch_offset[i + 1]
+                    pred = _soft_nms(bboxes[s:e], scores[s:e])
+                    out_offsets.append(pred.shape[0] + out_offsets[-1])
+                    pred_res.append(pred)
+            else:
+                assert len(bboxes.shape) == 3
+                assert len(scores.shape) == 3
+                for i in range(bboxes.shape[0]):
+                    pred = _soft_nms(bboxes[i], scores[i])
+                    out_offsets.append(pred.shape[0] + out_offsets[-1])
+                    pred_res.append(pred)
+
+            res = fluid.LoDTensor()
+            res.set_lod([out_offsets])
+            if len(pred_res) == 0:
+                pred_res = np.array([[1]], dtype=np.float32)
+            res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
+            return res
+
+        pred_result = create_tmp_var(
+            fluid.default_main_program(),
+            name='softnms_pred_result',
+            dtype='float32',
+            shape=[-1, 6],
+            lod_level=1)
+        fluid.layers.py_func(
+            func=_batch_softnms, x=[bboxes, scores], out=pred_result)
+        return pred_result

+ 305 - 113
paddlex/cv/nets/detection/yolo_v3.py

@@ -16,25 +16,50 @@ from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
 from collections import OrderedDict
+from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
+from .ops import DropBlock
+from .loss.yolo_loss import YOLOv3Loss
+from .loss.iou_loss import IouLoss
+from .loss.iou_aware_loss import IouAwareLoss
+from .iou_aware import get_iou_aware_score
+try:
+    from collections.abc import Sequence
+except Exception:
+    from collections import Sequence
 
 
 class YOLOv3:
-    def __init__(self,
-                 backbone,
-                 num_classes,
-                 mode='train',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 label_smooth=False,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ],
-                 fixed_input_shape=None):
+    def __init__(
+            self,
+            backbone,
+            mode='train',
+            # YOLOv3Head
+            num_classes=80,
+            anchors=None,
+            anchor_masks=None,
+            coord_conv=False,
+            iou_aware=False,
+            iou_aware_factor=0.4,
+            scale_x_y=1.,
+            spp=False,
+            drop_block=False,
+            use_matrix_nms=False,
+            # YOLOv3Loss
+            batch_size=8,
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_fine_grained_loss=False,
+            use_iou_loss=False,
+            iou_loss_weight=2.5,
+            iou_aware_loss_weight=1.0,
+            max_height=608,
+            max_width=608,
+            # NMS
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            fixed_input_shape=None):
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -46,56 +71,114 @@ class YOLOv3:
         self.mode = mode
         self.num_classes = num_classes
         self.backbone = backbone
-        self.ignore_thresh = ignore_threshold
-        self.label_smooth = label_smooth
-        self.nms_score_threshold = nms_score_threshold
-        self.nms_topk = nms_topk
-        self.nms_keep_topk = nms_keep_topk
-        self.nms_iou_threshold = nms_iou_threshold
         self.norm_decay = 0.0
         self.prefix_name = ''
-        self.train_random_shapes = train_random_shapes
+        self.use_fine_grained_loss = use_fine_grained_loss
         self.fixed_input_shape = fixed_input_shape
+        self.coord_conv = coord_conv
+        self.iou_aware = iou_aware
+        self.iou_aware_factor = iou_aware_factor
+        self.scale_x_y = scale_x_y
+        self.use_spp = spp
+        self.drop_block = drop_block
 
-    def _head(self, feats):
+        if use_matrix_nms:
+            self.nms = MatrixNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold,
+                post_threshold=0.01)
+        else:
+            self.nms = MultiClassNMS(
+                background_label=-1,
+                keep_top_k=nms_keep_topk,
+                nms_threshold=nms_iou_threshold,
+                nms_top_k=nms_topk,
+                normalized=False,
+                score_threshold=nms_score_threshold)
+        self.iou_loss = None
+        self.iou_aware_loss = None
+        if use_iou_loss:
+            self.iou_loss = IouLoss(
+                loss_weight=iou_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        if iou_aware:
+            self.iou_aware_loss = IouAwareLoss(
+                loss_weight=iou_aware_loss_weight,
+                max_height=max_height,
+                max_width=max_width)
+        self.yolo_loss = YOLOv3Loss(
+            batch_size=batch_size,
+            ignore_thresh=ignore_threshold,
+            scale_x_y=scale_x_y,
+            label_smooth=label_smooth,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            iou_loss=self.iou_loss,
+            iou_aware_loss=self.iou_aware_loss)
+        self.conv_block_num = 2
+        self.block_size = 3
+        self.keep_prob = 0.9
+        self.downsample = [32, 16, 8]
+        self.clip_bbox = True
+
+    def _head(self, input, is_train=True):
         outputs = []
+
+        # get last out_layer_num blocks in reverse order
         out_layer_num = len(self.anchor_masks)
-        blocks = feats[-1:-out_layer_num - 1:-1]
-        route = None
+        blocks = input[-1:-out_layer_num - 1:-1]
 
+        route = None
         for i, block in enumerate(blocks):
-            if i > 0:
+            if i > 0:  # perform concat in first 2 detection_block
                 block = fluid.layers.concat(input=[route, block], axis=1)
             route, tip = self._detection_block(
                 block,
-                channel=512 // (2**i),
-                name=self.prefix_name + 'yolo_block.{}'.format(i))
+                channel=64 * (2**out_layer_num) // (2**i),
+                is_first=i == 0,
+                is_test=(not is_train),
+                conv_block_num=self.conv_block_num,
+                name=self.prefix_name + "yolo_block.{}".format(i))
 
-            num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
-            block_out = fluid.layers.conv2d(
-                input=tip,
-                num_filters=num_filters,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                act=None,
-                param_attr=ParamAttr(name=self.prefix_name +
-                                     'yolo_output.{}.conv.weights'.format(i)),
-                bias_attr=ParamAttr(
-                    regularizer=L2Decay(0.0),
-                    name=self.prefix_name +
-                    'yolo_output.{}.conv.bias'.format(i)))
-            outputs.append(block_out)
+            # out channel number = mask_num * (5 + class_num)
+            if self.iou_aware:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 6)
+            else:
+                num_filters = len(self.anchor_masks[i]) * (
+                    self.num_classes + 5)
+            with fluid.name_scope('yolo_output'):
+                block_out = fluid.layers.conv2d(
+                    input=tip,
+                    num_filters=num_filters,
+                    filter_size=1,
+                    stride=1,
+                    padding=0,
+                    act=None,
+                    param_attr=ParamAttr(
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.weights".format(i)),
+                    bias_attr=ParamAttr(
+                        regularizer=L2Decay(0.),
+                        name=self.prefix_name +
+                        "yolo_output.{}.conv.bias".format(i)))
+                outputs.append(block_out)
 
             if i < len(blocks) - 1:
+                # do not perform upsample in the last detection_block
                 route = self._conv_bn(
                     input=route,
                     ch_out=256 // (2**i),
                     filter_size=1,
                     stride=1,
                     padding=0,
-                    name=self.prefix_name + 'yolo_transition.{}'.format(i))
+                    is_test=(not is_train),
+                    name=self.prefix_name + "yolo_transition.{}".format(i))
+                # upsample
                 route = self._upsample(route)
+
         return outputs
 
     def _parse_anchors(self, anchors):
@@ -116,6 +199,54 @@ class YOLOv3:
                 assert mask < anchor_num, "anchor mask index overflow"
                 self.mask_anchors[-1].extend(anchors[mask])
 
+    def _create_tensor_from_numpy(self, numpy_array):
+        paddle_array = fluid.layers.create_global_var(
+            shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
+        fluid.layers.assign(numpy_array, paddle_array)
+        return paddle_array
+
+    def _add_coord(self, input, is_test=True):
+        if not self.coord_conv:
+            return input
+
+        # NOTE: here is used for exporting model for TensorRT inference,
+        #       only support batch_size=1 for input shape should be fixed,
+        #       and we create tensor with fixed shape from numpy array
+        if is_test and input.shape[2] > 0 and input.shape[3] > 0:
+            batch_size = 1
+            grid_x = int(input.shape[3])
+            grid_y = int(input.shape[2])
+            idx_i = np.array(
+                [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
+                dtype='float32')
+            gi_np = np.repeat(idx_i, grid_y, axis=0)
+            gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
+            gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
+
+            x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
+            x_range.stop_gradient = True
+            y_range = self._create_tensor_from_numpy(
+                gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
+            y_range.stop_gradient = True
+
+        # NOTE: in training mode, H and W is variable for random shape,
+        #       implement add_coord with shape as Variable
+        else:
+            input_shape = fluid.layers.shape(input)
+            b = input_shape[0]
+            h = input_shape[2]
+            w = input_shape[3]
+
+            x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
+            x_range = x_range - 1.
+            x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
+            x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
+            x_range.stop_gradient = True
+            y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
+            y_range.stop_gradient = True
+
+        return fluid.layers.concat([input, x_range, y_range], axis=1)
+
     def _conv_bn(self,
                  input,
                  ch_out,
@@ -151,18 +282,52 @@ class YOLOv3:
             out = fluid.layers.leaky_relu(x=out, alpha=0.1)
         return out
 
+    def _spp_module(self, input, is_test=True, name=""):
+        output1 = input
+        output2 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=5,
+            pool_stride=1,
+            pool_padding=2,
+            ceil_mode=False,
+            pool_type='max')
+        output3 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=9,
+            pool_stride=1,
+            pool_padding=4,
+            ceil_mode=False,
+            pool_type='max')
+        output4 = fluid.layers.pool2d(
+            input=output1,
+            pool_size=13,
+            pool_stride=1,
+            pool_padding=6,
+            ceil_mode=False,
+            pool_type='max')
+        output = fluid.layers.concat(
+            input=[output1, output2, output3, output4], axis=1)
+        return output
+
     def _upsample(self, input, scale=2, name=None):
         out = fluid.layers.resize_nearest(
             input=input, scale=float(scale), name=name)
         return out
 
-    def _detection_block(self, input, channel, name=None):
-        assert channel % 2 == 0, "channel({}) cannot be divided by 2 in detection block({})".format(
-            channel, name)
+    def _detection_block(self,
+                         input,
+                         channel,
+                         conv_block_num=2,
+                         is_first=False,
+                         is_test=True,
+                         name=None):
+        assert channel % 2 == 0, \
+            "channel {} cannot be divided by 2 in detection block {}" \
+            .format(channel, name)
 
-        is_test = False if self.mode == 'train' else True
         conv = input
-        for i in range(2):
+        for j in range(conv_block_num):
+            conv = self._add_coord(conv, is_test=is_test)
             conv = self._conv_bn(
                 conv,
                 channel,
@@ -170,7 +335,17 @@ class YOLOv3:
                 stride=1,
                 padding=0,
                 is_test=is_test,
-                name='{}.{}.0'.format(name, i))
+                name='{}.{}.0'.format(name, j))
+            if self.use_spp and is_first and j == 1:
+                conv = self._spp_module(conv, is_test=is_test, name="spp")
+                conv = self._conv_bn(
+                    conv,
+                    512,
+                    filter_size=1,
+                    stride=1,
+                    padding=0,
+                    is_test=is_test,
+                    name='{}.{}.spp.conv'.format(name, j))
             conv = self._conv_bn(
                 conv,
                 channel * 2,
@@ -178,7 +353,21 @@ class YOLOv3:
                 stride=1,
                 padding=1,
                 is_test=is_test,
-                name='{}.{}.1'.format(name, i))
+                name='{}.{}.1'.format(name, j))
+            if self.drop_block and j == 0 and not is_first:
+                conv = DropBlock(
+                    conv,
+                    block_size=self.block_size,
+                    keep_prob=self.keep_prob,
+                    is_test=is_test)
+
+        if self.drop_block and is_first:
+            conv = DropBlock(
+                conv,
+                block_size=self.block_size,
+                keep_prob=self.keep_prob,
+                is_test=is_test)
+        conv = self._add_coord(conv, is_test=is_test)
         route = self._conv_bn(
             conv,
             channel,
@@ -187,8 +376,9 @@ class YOLOv3:
             padding=0,
             is_test=is_test,
             name='{}.2'.format(name))
+        new_route = self._add_coord(route, is_test=is_test)
         tip = self._conv_bn(
-            route,
+            new_route,
             channel * 2,
             filter_size=3,
             stride=1,
@@ -197,54 +387,44 @@ class YOLOv3:
             name='{}.tip'.format(name))
         return route, tip
 
-    def _get_loss(self, inputs, gt_box, gt_label, gt_score):
-        losses = []
-        downsample = 32
-        for i, input in enumerate(inputs):
-            loss = fluid.layers.yolov3_loss(
-                x=input,
-                gt_box=gt_box,
-                gt_label=gt_label,
-                gt_score=gt_score,
-                anchors=self.anchors,
-                anchor_mask=self.anchor_masks[i],
-                class_num=self.num_classes,
-                ignore_thresh=self.ignore_thresh,
-                downsample_ratio=downsample,
-                use_label_smooth=self.label_smooth,
-                name=self.prefix_name + 'yolo_loss' + str(i))
-            losses.append(fluid.layers.reduce_mean(loss))
-            downsample //= 2
-        return sum(losses)
+    def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
+        loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
+                              self.anchors, self.anchor_masks,
+                              self.mask_anchors, self.num_classes,
+                              self.prefix_name)
+        total_loss = fluid.layers.sum(list(loss.values()))
+        return total_loss
 
     def _get_prediction(self, inputs, im_size):
         boxes = []
         scores = []
-        downsample = 32
         for i, input in enumerate(inputs):
+            if self.iou_aware:
+                input = get_iou_aware_score(input,
+                                            len(self.anchor_masks[i]),
+                                            self.num_classes,
+                                            self.iou_aware_factor)
+            scale_x_y = self.scale_x_y if not isinstance(
+                self.scale_x_y, Sequence) else self.scale_x_y[i]
+
             box, score = fluid.layers.yolo_box(
                 x=input,
                 img_size=im_size,
                 anchors=self.mask_anchors[i],
                 class_num=self.num_classes,
-                conf_thresh=self.nms_score_threshold,
-                downsample_ratio=downsample,
-                name=self.prefix_name + 'yolo_box' + str(i))
+                conf_thresh=self.nms.score_threshold,
+                downsample_ratio=self.downsample[i],
+                name=self.prefix_name + 'yolo_box' + str(i),
+                clip_bbox=self.clip_bbox,
+                scale_x_y=self.scale_x_y)
             boxes.append(box)
             scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
-            downsample //= 2
+
         yolo_boxes = fluid.layers.concat(boxes, axis=1)
         yolo_scores = fluid.layers.concat(scores, axis=2)
-        pred = fluid.layers.multiclass_nms(
-            bboxes=yolo_boxes,
-            scores=yolo_scores,
-            score_threshold=self.nms_score_threshold,
-            nms_top_k=self.nms_topk,
-            keep_top_k=self.nms_keep_topk,
-            nms_threshold=self.nms_iou_threshold,
-            normalized=False,
-            nms_eta=1.0,
-            background_label=-1)
+        if type(self.nms) is MultiClassSoftNMS:
+            yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
+        pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
         return pred
 
     def generate_inputs(self):
@@ -267,6 +447,25 @@ class YOLOv3:
                 dtype='float32', shape=[None, None], name='gt_score')
             inputs['im_size'] = fluid.data(
                 dtype='int32', shape=[None, 2], name='im_size')
+            if self.use_fine_grained_loss:
+                downsample = 32
+                for i, mask in enumerate(self.anchor_masks):
+                    if self.fixed_input_shape is not None:
+                        target_shape = [
+                            self.fixed_input_shape[1] // downsample,
+                            self.fixed_input_shape[0] // downsample
+                        ]
+                    else:
+                        target_shape = [None, None]
+                    inputs['target{}'.format(i)] = fluid.data(
+                        dtype='float32',
+                        lod_level=0,
+                        shape=[
+                            None, len(mask), 6 + self.num_classes,
+                            target_shape[0], target_shape[1]
+                        ],
+                        name='target{}'.format(i))
+                    downsample //= 2
         elif self.mode == 'eval':
             inputs['im_size'] = fluid.data(
                 dtype='int32', shape=[None, 2], name='im_size')
@@ -284,44 +483,37 @@ class YOLOv3:
         return inputs
 
     def build_net(self, inputs):
+        import numpy as np
         image = inputs['image']
-        if self.mode == 'train':
-            if isinstance(self.train_random_shapes,
-                          (list, tuple)) and len(self.train_random_shapes) > 0:
-                import numpy as np
-                shapes = np.array(self.train_random_shapes)
-                shapes = np.stack([shapes, shapes], axis=1).astype('float32')
-                shapes_tensor = fluid.layers.assign(shapes)
-                index = fluid.layers.uniform_random(
-                    shape=[1], dtype='float32', min=0.0, max=1)
-                index = fluid.layers.cast(
-                    index * len(self.train_random_shapes), dtype='int32')
-                shape = fluid.layers.gather(shapes_tensor, index)
-                shape = fluid.layers.reshape(shape, [-1])
-                shape = fluid.layers.cast(shape, dtype='int32')
-                image = fluid.layers.resize_nearest(
-                    image, out_shape=shape, align_corners=False)
         feats = self.backbone(image)
         if isinstance(feats, OrderedDict):
             feat_names = list(feats.keys())
             feats = [feats[name] for name in feat_names]
 
-        head_outputs = self._head(feats)
+        head_outputs = self._head(feats, self.mode == 'train')
         if self.mode == 'train':
             gt_box = inputs['gt_box']
             gt_label = inputs['gt_label']
             gt_score = inputs['gt_score']
             im_size = inputs['im_size']
-            num_boxes = fluid.layers.shape(gt_box)[1]
-            im_size_wh = fluid.layers.reverse(im_size, axis=1)
-            whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
-            whwh = fluid.layers.unsqueeze(whwh, axes=[1])
-            whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
-            whwh = fluid.layers.cast(whwh, dtype='float32')
-            whwh.stop_gradient = True
-            normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
+            #num_boxes = fluid.layers.shape(gt_box)[1]
+            #im_size_wh = fluid.layers.reverse(im_size, axis=1)
+            #whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
+            #whwh = fluid.layers.unsqueeze(whwh, axes=[1])
+            #whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
+            #whwh = fluid.layers.cast(whwh, dtype='float32')
+            #whwh.stop_gradient = True
+            #normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
+            normalized_box = gt_box
+
+            targets = []
+            if self.use_fine_grained_loss:
+                for i, mask in enumerate(self.anchor_masks):
+                    k = 'target{}'.format(i)
+                    if k in inputs:
+                        targets.append(inputs[k])
             return self._get_loss(head_outputs, normalized_box, gt_label,
-                                  gt_score)
+                                  gt_score, targets)
         else:
             im_size = inputs['im_size']
             return self._get_prediction(head_outputs, im_size)

+ 185 - 0
paddlex/cv/transforms/det_transforms.py

@@ -55,6 +55,7 @@ class Compose(DetTransform):
             raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
         self.transforms = transforms
+        self.batch_transforms = None
         self.use_mixup = False
         for t in self.transforms:
             if type(t).__name__ == 'MixupImage':
@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
                         mean=mean, std=std)
             ]
         super(ComposedYOLOv3Transforms, self).__init__(transforms)
+
+
+class BatchRandomShape(DetTransform):
+    """调整图像大小(resize)。
+
+    对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
+    注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
+
+    Args:
+        random_shapes (list): resize大小选择列表。
+            默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
+        interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
+            ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
+    Raises:
+        ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
+                    'AREA', 'LANCZOS4', 'RANDOM']中。
+    """
+
+    # The interpolation mode
+    interp_dict = {
+        'NEAREST': cv2.INTER_NEAREST,
+        'LINEAR': cv2.INTER_LINEAR,
+        'CUBIC': cv2.INTER_CUBIC,
+        'AREA': cv2.INTER_AREA,
+        'LANCZOS4': cv2.INTER_LANCZOS4
+    }
+
+    def __init__(
+            self,
+            random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+            interp='RANDOM'):
+        if not (interp == "RANDOM" or interp in self.interp_dict):
+            raise ValueError("interp should be one of {}".format(
+                self.interp_dict.keys()))
+        self.random_shapes = random_shapes
+        self.interp = interp
+
+    def __call__(self, batch_data):
+        """
+        Args:
+            batch_data (list): 由与图像相关的各种信息组成的batch数据。
+        Returns:
+            list: 由与图像相关的各种信息组成的batch数据。
+        """
+        shape = np.random.choice(self.random_shapes)
+
+        if self.interp == "RANDOM":
+            interp = random.choice(list(self.interp_dict.keys()))
+        else:
+            interp = self.interp
+        for data_id, data in enumerate(batch_data):
+            data_list = list(data)
+            im = data_list[0]
+            im = np.swapaxes(im, 1, 0)
+            im = np.swapaxes(im, 1, 2)
+            im = resize(im, shape, self.interp_dict[interp])
+            im = np.swapaxes(im, 1, 2)
+            im = np.swapaxes(im, 1, 0)
+            data_list[0] = im
+            batch_data[data_id] = tuple(data_list)
+        return batch_data
+
+
+class GenerateYoloTarget(object):
+    """生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
+       该transform只在YOLOv3计算细粒度loss时使用。
+
+       Args:
+           anchors (list|tuple): anchor框的宽度和高度。
+           anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
+           num_classes (int): 类别数。默认为80。
+           iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
+    """
+
+    def __init__(self,
+                 anchors,
+                 anchor_masks,
+                 downsample_ratios,
+                 num_classes=80,
+                 iou_thresh=1.):
+        super(GenerateYoloTarget, self).__init__()
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        self.downsample_ratios = downsample_ratios
+        self.num_classes = num_classes
+        self.iou_thresh = iou_thresh
+
+    def __call__(self, batch_data):
+        """
+        Args:
+            batch_data (list): 由与图像相关的各种信息组成的batch数据。
+        Returns:
+            list: 由与图像相关的各种信息组成的batch数据。
+                  其中,每个数据新添加的字段为:
+                           - target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
+                                   形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
+                           - target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
+                                   形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
+                           - ...
+                           -targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
+                                   形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
+                    n的是大小由anchor_masks的长度决定。
+        """
+        im = batch_data[0][0]
+        h = im.shape[1]
+        w = im.shape[2]
+        an_hw = np.array(self.anchors) / np.array([[w, h]])
+        for data_id, data in enumerate(batch_data):
+            gt_bbox = data[1]
+            gt_class = data[2]
+            gt_score = data[3]
+            im_shape = data[4]
+            origin_h = float(im_shape[0])
+            origin_w = float(im_shape[1])
+            data_list = list(data)
+            for i, (
+                    mask, downsample_ratio
+            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
+                grid_h = int(h / downsample_ratio)
+                grid_w = int(w / downsample_ratio)
+                target = np.zeros(
+                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
+                    dtype=np.float32)
+                for b in range(gt_bbox.shape[0]):
+                    gx = gt_bbox[b, 0] / float(origin_w)
+                    gy = gt_bbox[b, 1] / float(origin_h)
+                    gw = gt_bbox[b, 2] / float(origin_w)
+                    gh = gt_bbox[b, 3] / float(origin_h)
+                    cls = gt_class[b]
+                    score = gt_score[b]
+                    if gw <= 0. or gh <= 0. or score <= 0.:
+                        continue
+                    # find best match anchor index
+                    best_iou = 0.
+                    best_idx = -1
+                    for an_idx in range(an_hw.shape[0]):
+                        iou = jaccard_overlap(
+                            [0., 0., gw, gh],
+                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
+                        if iou > best_iou:
+                            best_iou = iou
+                            best_idx = an_idx
+                    gi = int(gx * grid_w)
+                    gj = int(gy * grid_h)
+                    # gtbox should be regresed in this layes if best match
+                    # anchor index in anchor mask of this layer
+                    if best_idx in mask:
+                        best_n = mask.index(best_idx)
+                        # x, y, w, h, scale
+                        target[best_n, 0, gj, gi] = gx * grid_w - gi
+                        target[best_n, 1, gj, gi] = gy * grid_h - gj
+                        target[best_n, 2, gj, gi] = np.log(
+                            gw * w / self.anchors[best_idx][0])
+                        target[best_n, 3, gj, gi] = np.log(
+                            gh * h / self.anchors[best_idx][1])
+                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
+                        # objectness record gt_score
+                        target[best_n, 5, gj, gi] = score
+                        # classification
+                        target[best_n, 6 + cls, gj, gi] = 1.
+                    # For non-matched anchors, calculate the target if the iou
+                    # between anchor and gt is larger than iou_thresh
+                    if self.iou_thresh < 1:
+                        for idx, mask_i in enumerate(mask):
+                            if mask_i == best_idx: continue
+                            iou = jaccard_overlap(
+                                [0., 0., gw, gh],
+                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
+                            if iou > self.iou_thresh:
+                                # x, y, w, h, scale
+                                target[idx, 0, gj, gi] = gx * grid_w - gi
+                                target[idx, 1, gj, gi] = gy * grid_h - gj
+                                target[idx, 2, gj, gi] = np.log(
+                                    gw * w / self.anchors[mask_i][0])
+                                target[idx, 3, gj, gi] = np.log(
+                                    gh * h / self.anchors[mask_i][1])
+                                target[idx, 4, gj, gi] = 2.0 - gw * gh
+                                # objectness record gt_score
+                                target[idx, 5, gj, gi] = score
+                                # classification
+                                target[idx, 6 + cls, gj, gi] = 1.
+                data_list.append(target)
+            batch_data[data_id] = tuple(data_list)
+        return batch_data

+ 69 - 0
tutorials/train/object_detection/ppyolo.py

@@ -0,0 +1,69 @@
+# 环境变量配置,用于控制是否使用GPU
+# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+# 下载和解压昆虫检测数据集
+insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(insect_dataset, path='./')
+
+# 定义训练和验证时的transforms
+# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
+train_transforms = transforms.Compose([
+    transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
+    transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
+        target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize(
+        target_size=608, interp='CUBIC'), transforms.Normalize()
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
+model = pdx.det.YOLOv3(
+    num_classes=num_classes,
+    backbone='ResNet50_vd',
+    with_dcn_v2=True,
+    use_coord_conv=True,
+    use_iou_aware=True,
+    use_spp=True,
+    use_drop_block=True,
+    scale_x_y=1.05,
+    use_iou_loss=True,
+    use_matrix_nms=True)
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
+# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=270,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.000125,
+    lr_decay_epochs=[210, 240],
+    use_ema=True,
+    save_dir='output/ppyolo',
+    use_vdl=True)