FlyingQianMM 5 years ago
parent
commit
6b8a6a7e9d

+ 8 - 0
paddlex/cv/datasets/dataset.py

@@ -205,6 +205,14 @@ def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
             batch_data = op(batch_data)
     # if batch_size is 1, do not pad the image
     if len(batch_data) == 1:
+        #im = np.load('/home/luoqianhui/PaddleDetection/image.npy')
+        #im_info = np.load('/home/luoqianhui/PaddleDetection/im_info.npy')
+        #box = np.load('/home/luoqianhui/PaddleDetection/gt_bbox.npy')
+        #id = np.load('/home/luoqianhui/PaddleDetection/gt_class.npy')
+        #diff = np.load('/home/luoqianhui/PaddleDetection/difficult.npy')
+        #im_shape = np.array([1920,2560,1], dtype=np.float32)
+        #batch_data = [(im, im_info, box, im_shape, id, diff)]
+        #batch_data = [(im, im_info, box, id, diff)]
         return batch_data
     width = [data[0].shape[2] for data in batch_data]
     height = [data[0].shape[1] for data in batch_data]
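
The early return above leaves a single-sample batch unpadded; the width/height lists that follow are used further down in generate_minibatch to pad every image up to the largest size in the batch. A rough standalone sketch of that padding idea (hypothetical helper; the actual padding code sits below this hunk and is not shown):

    import numpy as np

    def pad_to_batch_max(batch_data):
        # batch_data: list of tuples whose first element is a CHW image array
        max_h = max(data[0].shape[1] for data in batch_data)
        max_w = max(data[0].shape[2] for data in batch_data)
        padded = []
        for data in batch_data:
            im = data[0]
            canvas = np.zeros((im.shape[0], max_h, max_w), dtype=im.dtype)
            canvas[:, :im.shape[1], :im.shape[2]] = im  # top-left aligned, zero fill
            # label maps, when present, would be padded with label_padding_value instead
            padded.append((canvas,) + tuple(data[1:]))
        return padded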

+ 7 - 2
paddlex/cv/datasets/voc.py

@@ -104,8 +104,11 @@ class VOCDetection(Dataset):
                 if not osp.isfile(xml_file):
                     continue
                 if not osp.exists(img_file):
-                    raise IOError('The image file {} is not exist!'.format(
-                        img_file))
+                    #raise IOError('The image file {} is not exist!'.format(
+                    #    img_file))
+                    continue
+                if not osp.exists(xml_file):
+                    continue
                 tree = ET.parse(xml_file)
                 if tree.find('id') is None:
                     im_id = np.array([ct])
@@ -141,6 +144,8 @@ class VOCDetection(Dataset):
                     name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
                         1:-1]
                     cname = obj.find(name_tag).text.strip()
+                    if cname in ['bu_dao_dian', 'jiao_wei_lou_di']:
+                        cname = 'lou_di'
                     gt_class[i][0] = cname2cid[cname]
                     pattern = re.compile('<difficult>', re.IGNORECASE)
                     diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
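
The two added lines fold the raw VOC labels 'bu_dao_dian' and 'jiao_wei_lou_di' into a single 'lou_di' class at parse time. A slightly more general sketch of the same idea (hypothetical helper, not part of the commit), keeping the merges in one table:

    # hypothetical label-merge table; keys are raw <name> tags, values are merged classes
    CNAME_REMAP = {
        'bu_dao_dian': 'lou_di',
        'jiao_wei_lou_di': 'lou_di',
    }

    def remap_cname(cname, remap=CNAME_REMAP):
        return remap.get(cname, cname)  # unknown names pass through unchanged

    assert remap_cname('bu_dao_dian') == 'lou_di'
    assert remap_cname('person') == 'person'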

+ 5 - 1
paddlex/cv/models/base.py

@@ -34,6 +34,9 @@ from os import path as osp
 from paddle.fluid.framework import Program
 from .utils.pretrain_weights import get_pretrain_weights
 
+#fluid.default_startup_program().random_seed = 1000
+#fluid.default_main_program().random_seed = 1000
+
 
 def dict2str(dict_input):
     out = ''
@@ -544,7 +547,7 @@ class BaseAPI:
             time_train_one_epoch = time.time() - epoch_start_time
             epoch_start_time = time.time()
 
-            # Every save_interval_epochs epochs, evaluate on the validation set and save the model
+            ## Every save_interval_epochs epochs, evaluate on the validation set and save the model
             self.completed_epochs += 1
             eval_epoch_start_time = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
@@ -591,3 +594,4 @@ class BaseAPI:
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
                         break
+            #return
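
The commented seed lines near the imports (and the early return at the end of the epoch loop) point at making runs repeatable; a sketch of what actually pinning the seeds would look like, assuming the paddle.fluid 1.x API this repository targets:

    import random
    import numpy as np
    import paddle.fluid as fluid

    SEED = 1000
    random.seed(SEED)           # Python-level shuffling (e.g. dataset order)
    np.random.seed(SEED)        # NumPy-based samplers and augmentations
    fluid.default_startup_program().random_seed = SEED
    fluid.default_main_program().random_seed = SEED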

+ 30 - 2
paddlex/cv/models/faster_rcnn.py

@@ -50,7 +50,16 @@ class FasterRCNN(BaseAPI):
                  with_dcn=False,
                  rpn_cls_loss='SigmoidCrossEntropy',
                  rpn_focal_loss_alpha=0.25,
-                 rpn_focal_loss_gamma=2):
+                 rpn_focal_loss_gamma=2,
+                 rcnn_bbox_loss='SmoothL1Loss',
+                 rcnn_nms='MultiClassNMS',
+                 keep_top_k=100,
+                 nms_threshold=0.5,
+                 score_threshold=0.05,
+                 softnms_sigma=0.5,
+                 post_threshold=0.05,
+                 bbox_assigner='BBoxAssigner',
+                 fpn_num_channels=256):
         self.init_params = locals()
         super(FasterRCNN, self).__init__('detector')
         backbones = [
@@ -73,6 +82,15 @@ class FasterRCNN(BaseAPI):
         self.rpn_cls_loss = rpn_cls_loss
         self.rpn_focal_loss_alpha = rpn_focal_loss_alpha
         self.rpn_focal_loss_gamma = rpn_focal_loss_gamma
+        self.rcnn_bbox_loss = rcnn_bbox_loss
+        self.rcnn_nms = rcnn_nms
+        self.keep_top_k = keep_top_k
+        self.nms_threshold = nms_threshold
+        self.score_threshold = score_threshold
+        self.softnms_sigma = softnms_sigma
+        self.post_threshold = post_threshold
+        self.bbox_assigner = bbox_assigner
+        self.fpn_num_channels = fpn_num_channels
 
     def _get_backbone(self, backbone_name):
         norm_type = None
@@ -145,7 +163,16 @@ class FasterRCNN(BaseAPI):
             fixed_input_shape=self.fixed_input_shape,
             rpn_cls_loss=self.rpn_cls_loss,
             rpn_focal_loss_alpha=self.rpn_focal_loss_alpha,
-            rpn_focal_loss_gamma=self.rpn_focal_loss_gamma)
+            rpn_focal_loss_gamma=self.rpn_focal_loss_gamma,
+            rcnn_bbox_loss=self.rcnn_bbox_loss,
+            rcnn_nms=self.rcnn_nms,
+            keep_top_k=self.keep_top_k,
+            nms_threshold=self.nms_threshold,
+            score_threshold=self.score_threshold,
+            post_threshold=self.post_threshold,
+            softnms_sigma=self.softnms_sigma,
+            bbox_assigner=self.bbox_assigner,
+            fpn_num_channels=self.fpn_num_channels)
         inputs = model.generate_inputs()
         if mode == 'train':
             model_out = model.build_net(inputs)
@@ -187,6 +214,7 @@ class FasterRCNN(BaseAPI):
             end_lr=learning_rate)
         optimizer = fluid.optimizer.Momentum(
             learning_rate=lr_warmup,
+            #learning_rate=lr_decay,
             momentum=0.9,
             regularization=fluid.regularizer.L2Decay(1e-04))
         return optimizer
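
With the new constructor arguments the R-CNN head options can be selected from the high-level API; a hypothetical configuration sketch, assuming the usual paddlex.det.FasterRCNN entry point and using the option strings handled in bbox_head.py and faster_rcnn.py below:

    import paddlex as pdx

    model = pdx.det.FasterRCNN(
        num_classes=5,                      # foreground classes + background
        rcnn_bbox_loss='DIoULoss',          # default: 'SmoothL1Loss'
        rcnn_nms='MultiClassSoftNMS',       # also 'MultiClassNMS', 'MatrixNMS', 'MultiClassCiouNMS'
        keep_top_k=100,
        nms_threshold=0.5,
        score_threshold=0.05,
        softnms_sigma=0.5,
        bbox_assigner='LibraBBoxAssigner',  # default: 'BBoxAssigner'
        fpn_num_channels=256)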

+ 109 - 59
paddlex/cv/nets/detection/bbox_head.py

@@ -24,6 +24,9 @@ from paddle.fluid.initializer import Normal, Xavier
 from paddle.fluid.regularizer import L2Decay
 from paddle.fluid.initializer import MSRA
 
+from .loss.diou_loss import DiouLoss
+from .ops import MultiClassNMS, MatrixNMS, MultiClassSoftNMS, MultiClassDiouNMS
+
 __all__ = ['BBoxHead', 'TwoFCHead']
 
 
@@ -42,23 +45,27 @@ class TwoFCHead(object):
     def __call__(self, roi_feat):
         fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
 
-        fc6 = fluid.layers.fc(
-            input=roi_feat,
-            size=self.mlp_dim,
-            act='relu',
-            name='fc6',
-            param_attr=ParamAttr(
-                name='fc6_w', initializer=Xavier(fan_out=fan)),
-            bias_attr=ParamAttr(
-                name='fc6_b', learning_rate=2., regularizer=L2Decay(0.)))
-        head_feat = fluid.layers.fc(
-            input=fc6,
-            size=self.mlp_dim,
-            act='relu',
-            name='fc7',
-            param_attr=ParamAttr(name='fc7_w', initializer=Xavier()),
-            bias_attr=ParamAttr(
-                name='fc7_b', learning_rate=2., regularizer=L2Decay(0.)))
+        fc6 = fluid.layers.fc(input=roi_feat,
+                              size=self.mlp_dim,
+                              act='relu',
+                              name='fc6',
+                              param_attr=ParamAttr(
+                                  name='fc6_w',
+                                  initializer=Xavier(fan_out=fan)),
+                              bias_attr=ParamAttr(
+                                  name='fc6_b',
+                                  learning_rate=2.,
+                                  regularizer=L2Decay(0.)))
+        head_feat = fluid.layers.fc(input=fc6,
+                                    size=self.mlp_dim,
+                                    act='relu',
+                                    name='fc7',
+                                    param_attr=ParamAttr(
+                                        name='fc7_w', initializer=Xavier()),
+                                    bias_attr=ParamAttr(
+                                        name='fc7_b',
+                                        learning_rate=2.,
+                                        regularizer=L2Decay(0.)))
 
         return head_feat
 
@@ -73,6 +80,7 @@ class BBoxHead(object):
             box_normalized=False,
             axis=1,
             #MultiClassNMS
+            rcnn_nms='MultiClassNMS',
             score_threshold=.05,
             nms_top_k=-1,
             keep_top_k=100,
@@ -80,25 +88,63 @@ class BBoxHead(object):
             normalized=False,
             nms_eta=1.0,
             background_label=0,
+            post_threshold=.05,
+            softnms_sigma=0.5,
             #bbox_loss
             sigma=1.0,
-            num_classes=81):
+            num_classes=81,
+            rcnn_bbox_loss='SmoothL1Loss',
+            diouloss_weight=10.0,
+            diouloss_is_cls_agnostic=False,
+            diouloss_use_complete_iou_loss=True):
         super(BBoxHead, self).__init__()
         self.head = head
         self.prior_box_var = prior_box_var
         self.code_type = code_type
         self.box_normalized = box_normalized
         self.axis = axis
-        self.score_threshold = score_threshold
-        self.nms_top_k = nms_top_k
-        self.keep_top_k = keep_top_k
-        self.nms_threshold = nms_threshold
-        self.normalized = normalized
-        self.nms_eta = nms_eta
-        self.background_label = background_label
         self.sigma = sigma
         self.num_classes = num_classes
         self.head_feat = None
+        self.rcnn_bbox_loss = rcnn_bbox_loss
+        self.diouloss_weight = diouloss_weight
+        self.diouloss_is_cls_agnostic = diouloss_is_cls_agnostic
+        self.diouloss_use_complete_iou_loss = diouloss_use_complete_iou_loss
+        if self.rcnn_bbox_loss == 'DIoULoss':
+            self.diou_loss = DiouLoss(
+                loss_weight=self.diouloss_weight,
+                is_cls_agnostic=self.diouloss_is_cls_agnostic,
+                num_classes=num_classes,
+                use_complete_iou_loss=self.diouloss_use_complete_iou_loss)
+        if rcnn_nms == 'MultiClassNMS':
+            self.nms = MultiClassNMS(
+                score_threshold=score_threshold,
+                keep_top_k=keep_top_k,
+                nms_threshold=nms_threshold,
+                normalized=normalized,
+                nms_eta=nms_eta,
+                background_label=background_label)
+        elif rcnn_nms == 'MultiClassSoftNMS':
+            self.nms = MultiClassSoftNMS(
+                score_threshold=score_threshold,
+                keep_top_k=keep_top_k,
+                softnms_sigma=softnms_sigma,
+                normalized=normalized,
+                background_label=background_label)
+        elif rcnn_nms == 'MatrixNMS':
+            self.nms = MatrixNMS(
+                score_threshold=score_threshold,
+                post_threshold=post_threshold,
+                keep_top_k=keep_top_k,
+                normalized=normalized,
+                background_label=background_label)
+        elif rcnn_nms == 'MultiClassCiouNMS':
+            self.nms = MultiClassDiouNMS(
+                score_threshold=score_threshold,
+                keep_top_k=keep_top_k,
+                nms_threshold=nms_threshold,
+                normalized=normalized,
+                background_label=background_label)
 
     def get_head_feat(self, input=None):
         """
@@ -130,24 +176,30 @@ class BBoxHead(object):
         if not isinstance(self.head, TwoFCHead):
             head_feat = fluid.layers.pool2d(
                 head_feat, pool_type='avg', global_pooling=True)
-        cls_score = fluid.layers.fc(
-            input=head_feat,
-            size=self.num_classes,
-            act=None,
-            name='cls_score',
-            param_attr=ParamAttr(
-                name='cls_score_w', initializer=Normal(loc=0.0, scale=0.01)),
-            bias_attr=ParamAttr(
-                name='cls_score_b', learning_rate=2., regularizer=L2Decay(0.)))
-        bbox_pred = fluid.layers.fc(
-            input=head_feat,
-            size=4 * self.num_classes,
-            act=None,
-            name='bbox_pred',
-            param_attr=ParamAttr(
-                name='bbox_pred_w', initializer=Normal(loc=0.0, scale=0.001)),
-            bias_attr=ParamAttr(
-                name='bbox_pred_b', learning_rate=2., regularizer=L2Decay(0.)))
+        cls_score = fluid.layers.fc(input=head_feat,
+                                    size=self.num_classes,
+                                    act=None,
+                                    name='cls_score',
+                                    param_attr=ParamAttr(
+                                        name='cls_score_w',
+                                        initializer=Normal(
+                                            loc=0.0, scale=0.01)),
+                                    bias_attr=ParamAttr(
+                                        name='cls_score_b',
+                                        learning_rate=2.,
+                                        regularizer=L2Decay(0.)))
+        bbox_pred = fluid.layers.fc(input=head_feat,
+                                    size=4 * self.num_classes,
+                                    act=None,
+                                    name='bbox_pred',
+                                    param_attr=ParamAttr(
+                                        name='bbox_pred_w',
+                                        initializer=Normal(
+                                            loc=0.0, scale=0.001)),
+                                    bias_attr=ParamAttr(
+                                        name='bbox_pred_b',
+                                        learning_rate=2.,
+                                        regularizer=L2Decay(0.)))
         return cls_score, bbox_pred
 
     def get_loss(self, roi_feat, labels_int32, bbox_targets,
@@ -179,12 +231,19 @@ class BBoxHead(object):
         loss_cls = fluid.layers.softmax_with_cross_entropy(
             logits=cls_score, label=labels_int64, numeric_stable_mode=True)
         loss_cls = fluid.layers.reduce_mean(loss_cls)
-        loss_bbox = fluid.layers.smooth_l1(
-            x=bbox_pred,
-            y=bbox_targets,
-            inside_weight=bbox_inside_weights,
-            outside_weight=bbox_outside_weights,
-            sigma=self.sigma)
+        if self.rcnn_bbox_loss == 'SmoothL1Loss':
+            loss_bbox = fluid.layers.smooth_l1(
+                x=bbox_pred,
+                y=bbox_targets,
+                inside_weight=bbox_inside_weights,
+                outside_weight=bbox_outside_weights,
+                sigma=self.sigma)
+        elif self.rcnn_bbox_loss == 'DIoULoss':
+            loss_bbox = self.diou_loss(
+                x=bbox_pred,
+                y=bbox_targets,
+                inside_weight=bbox_inside_weights,
+                outside_weight=bbox_outside_weights)
         loss_bbox = fluid.layers.reduce_mean(loss_bbox)
         return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
 
@@ -229,14 +288,5 @@ class BBoxHead(object):
         cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
         if return_box_score:
             return {'bbox': cliped_box, 'score': cls_prob}
-        pred_result = fluid.layers.multiclass_nms(
-            bboxes=cliped_box,
-            scores=cls_prob,
-            score_threshold=self.score_threshold,
-            nms_top_k=self.nms_top_k,
-            keep_top_k=self.keep_top_k,
-            nms_threshold=self.nms_threshold,
-            normalized=self.normalized,
-            nms_eta=self.nms_eta,
-            background_label=self.background_label)
+        pred_result = self.nms(bboxes=cliped_box, scores=cls_prob)
         return {'bbox': pred_result}

+ 48 - 14
paddlex/cv/nets/detection/faster_rcnn.py

@@ -26,6 +26,8 @@ from .rpn_head import (RPNHead, FPNRPNHead)
 from .roi_extractor import (RoIAlign, FPNRoIAlign)
 from .bbox_head import (BBoxHead, TwoFCHead)
 from ..resnet import ResNetC5
+from .loss.diou_loss import DiouLoss
+from .ops import BBoxAssigner, LibraBBoxAssigner
 
 __all__ = ['FasterRCNN']
 
@@ -73,6 +75,9 @@ class FasterRCNN(object):
             keep_top_k=100,
             nms_threshold=0.5,
             score_threshold=0.05,
+            rcnn_nms='MultiClassNMS',
+            softnms_sigma=0.5,
+            post_threshold=.05,
             #bbox_assigner
             batch_size_per_im=512,
             fg_fraction=.25,
@@ -80,7 +85,13 @@ class FasterRCNN(object):
             bg_thresh_hi=.5,
             bg_thresh_lo=0.,
             bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
-            fixed_input_shape=None):
+            fixed_input_shape=None,
+            rcnn_bbox_loss='SmoothL1Loss',
+            diouloss_weight=10.0,
+            diouloss_is_cls_agnostic=False,
+            diouloss_use_complete_iou_loss=True,
+            bbox_assigner='BBoxAssigner',
+            fpn_num_channels=256):
         super(FasterRCNN, self).__init__()
         self.backbone = backbone
         self.mode = mode
@@ -92,6 +103,7 @@ class FasterRCNN(object):
             else:
                 fpn = FPN()
         self.fpn = fpn
+        self.fpn.num_chan = fpn_num_channels
         self.num_classes = num_classes
         if rpn_head is None:
             if self.fpn is None:
@@ -110,7 +122,8 @@ class FasterRCNN(object):
                     test_nms_thresh=test_nms_thresh,
                     rpn_cls_loss=rpn_cls_loss,
                     rpn_focal_loss_alpha=rpn_focal_loss_alpha,
-                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
+                    rpn_focal_loss_gamma=rpn_focal_loss_gamma,
+                    use_random=False)
             else:
                 rpn_head = FPNRPNHead(
                     anchor_start_size=anchor_sizes[0],
@@ -130,7 +143,8 @@ class FasterRCNN(object):
                     test_nms_thresh=test_nms_thresh,
                     rpn_cls_loss=rpn_cls_loss,
                     rpn_focal_loss_alpha=rpn_focal_loss_alpha,
-                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
+                    rpn_focal_loss_gamma=rpn_focal_loss_gamma,
+                    use_random=False)
         self.rpn_head = rpn_head
         if roi_extractor is None:
             if self.fpn is None:
@@ -154,7 +168,15 @@ class FasterRCNN(object):
                 keep_top_k=keep_top_k,
                 nms_threshold=nms_threshold,
                 score_threshold=score_threshold,
-                num_classes=num_classes)
+                rcnn_nms=rcnn_nms,
+                softnms_sigma=softnms_sigma,
+                post_threshold=post_threshold,
+                num_classes=num_classes,
+                rcnn_bbox_loss=rcnn_bbox_loss,
+                diouloss_weight=diouloss_weight,
+                diouloss_is_cls_agnostic=diouloss_is_cls_agnostic,
+                diouloss_use_complete_iou_loss=diouloss_use_complete_iou_loss)
+
         self.bbox_head = bbox_head
         self.batch_size_per_im = batch_size_per_im
         self.fg_fraction = fg_fraction
@@ -164,6 +186,26 @@ class FasterRCNN(object):
         self.bbox_reg_weights = bbox_reg_weights
         self.rpn_only = rpn_only
         self.fixed_input_shape = fixed_input_shape
+        if bbox_assigner == 'BBoxAssigner':
+            self.bbox_assigner = BBoxAssigner(
+                batch_size_per_im=batch_size_per_im,
+                fg_fraction=fg_fraction,
+                fg_thresh=fg_thresh,
+                bg_thresh_hi=bg_thresh_hi,
+                bg_thresh_lo=bg_thresh_lo,
+                bbox_reg_weights=bbox_reg_weights,
+                num_classes=num_classes,
+                shuffle_before_sample=self.rpn_head.use_random)
+        elif bbox_assigner == 'LibraBBoxAssigner':
+            self.bbox_assigner = LibraBBoxAssigner(
+                batch_size_per_im=batch_size_per_im,
+                fg_fraction=fg_fraction,
+                fg_thresh=fg_thresh,
+                bg_thresh_hi=bg_thresh_hi,
+                bg_thresh_lo=bg_thresh_lo,
+                bbox_reg_weights=bbox_reg_weights,
+                num_classes=num_classes,
+                shuffle_before_sample=self.rpn_head.use_random)
 
     def build_net(self, inputs):
         im = inputs['image']
@@ -184,20 +226,12 @@ class FasterRCNN(object):
 
         if self.mode == 'train':
             rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
-            outputs = fluid.layers.generate_proposal_labels(
+            outputs = self.bbox_assigner(
                 rpn_rois=rois,
                 gt_classes=inputs['gt_label'],
                 is_crowd=inputs['is_crowd'],
                 gt_boxes=inputs['gt_box'],
-                im_info=inputs['im_info'],
-                batch_size_per_im=self.batch_size_per_im,
-                fg_fraction=self.fg_fraction,
-                fg_thresh=self.fg_thresh,
-                bg_thresh_hi=self.bg_thresh_hi,
-                bg_thresh_lo=self.bg_thresh_lo,
-                bbox_reg_weights=self.bbox_reg_weights,
-                class_nums=self.num_classes,
-                use_random=self.rpn_head.use_random)
+                im_info=inputs['im_info'])
 
             rois = outputs[0]
             labels_int32 = outputs[1]
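
BBoxAssigner and LibraBBoxAssigner (both added in ops.py below) expose the same call signature, which is what lets build_net replace the direct fluid.layers.generate_proposal_labels call with self.bbox_assigner(...). A minimal sketch of that dispatch (hypothetical helper; the real selection happens in __init__ above):

    from paddlex.cv.nets.detection.ops import BBoxAssigner, LibraBBoxAssigner

    def make_bbox_assigner(name, **kwargs):
        assigners = {'BBoxAssigner': BBoxAssigner,
                     'LibraBBoxAssigner': LibraBBoxAssigner}
        if name not in assigners:
            raise ValueError('unknown bbox_assigner: {}'.format(name))
        return assigners[name](**kwargs)

    # both variants are then called identically inside build_net:
    # outputs = assigner(rpn_rois=rois, gt_classes=..., is_crowd=..., gt_boxes=..., im_info=...)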

+ 696 - 0
paddlex/cv/nets/detection/ops.py

@@ -21,6 +21,63 @@ import paddle
 from paddle import fluid
 
 
+def bbox_overlaps(boxes_1, boxes_2):
+    '''
+    bbox_overlaps
+        boxes_1: x1, y1, x2, y2
+        boxes_2: x1, y1, x2, y2
+    '''
+    assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4
+
+    num_1 = boxes_1.shape[0]
+    num_2 = boxes_2.shape[0]
+
+    x1_1 = boxes_1[:, 0:1]
+    y1_1 = boxes_1[:, 1:2]
+    x2_1 = boxes_1[:, 2:3]
+    y2_1 = boxes_1[:, 3:4]
+    area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1)
+
+    x1_2 = boxes_2[:, 0].transpose()
+    y1_2 = boxes_2[:, 1].transpose()
+    x2_2 = boxes_2[:, 2].transpose()
+    y2_2 = boxes_2[:, 3].transpose()
+    area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1)
+
+    xx1 = np.maximum(x1_1, x1_2)
+    yy1 = np.maximum(y1_1, y1_2)
+    xx2 = np.minimum(x2_1, x2_2)
+    yy2 = np.minimum(y2_1, y2_2)
+
+    w = np.maximum(0.0, xx2 - xx1 + 1)
+    h = np.maximum(0.0, yy2 - yy1 + 1)
+    inter = w * h
+
+    ovr = inter / (area_1 + area_2 - inter)
+    return ovr
+
+
+def box_to_delta(ex_boxes, gt_boxes, weights):
+    """ box_to_delta """
+    ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
+    ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
+    ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
+    ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h
+
+    gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
+    gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
+    gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
+    gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h
+
+    dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
+    dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
+    dw = (np.log(gt_w / ex_w)) / weights[2]
+    dh = (np.log(gt_h / ex_h)) / weights[3]
+
+    targets = np.vstack([dx, dy, dw, dh]).transpose()
+    return targets
+
+
 def DropBlock(input, block_size, keep_prob, is_test):
     if is_test:
         return input
@@ -268,3 +325,642 @@ class MultiClassSoftNMS(object):
         fluid.layers.py_func(
             func=_batch_softnms, x=[bboxes, scores], out=pred_result)
         return pred_result
+
+
+class MultiClassDiouNMS(object):
+    def __init__(
+            self,
+            score_threshold=0.05,
+            keep_top_k=100,
+            nms_threshold=0.5,
+            normalized=False,
+            background_label=0, ):
+        super(MultiClassDiouNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.nms_threshold = nms_threshold
+        self.keep_top_k = keep_top_k
+        self.normalized = normalized
+        self.background_label = background_label
+
+    def __call__(self, bboxes, scores):
+        def create_tmp_var(program, name, dtype, shape, lod_level):
+            return program.current_block().create_var(
+                name=name, dtype=dtype, shape=shape, lod_level=lod_level)
+
+        def _calc_diou_term(dets1, dets2):
+            eps = 1.e-10
+            eta = 0 if self.normalized else 1
+
+            x1, y1, x2, y2 = dets1[0], dets1[1], dets1[2], dets1[3]
+            x1g, y1g, x2g, y2g = dets2[0], dets2[1], dets2[2], dets2[3]
+
+            cx = (x1 + x2) / 2
+            cy = (y1 + y2) / 2
+            w = x2 - x1 + eta
+            h = y2 - y1 + eta
+
+            cxg = (x1g + x2g) / 2
+            cyg = (y1g + y2g) / 2
+            wg = x2g - x1g + eta
+            hg = y2g - y1g + eta
+
+            x2 = np.maximum(x1, x2)
+            y2 = np.maximum(y1, y2)
+
+            # A or B
+            xc1 = np.minimum(x1, x1g)
+            yc1 = np.minimum(y1, y1g)
+            xc2 = np.maximum(x2, x2g)
+            yc2 = np.maximum(y2, y2g)
+
+            # DIOU term
+            dist_intersection = (cx - cxg)**2 + (cy - cyg)**2
+            dist_union = (xc2 - xc1)**2 + (yc2 - yc1)**2
+            diou_term = (dist_intersection + eps) / (dist_union + eps)
+            return diou_term
+
+        def _diou_nms_for_cls(dets, thres):
+            """_diou_nms_for_cls"""
+            scores = dets[:, 0]
+            x1 = dets[:, 1]
+            y1 = dets[:, 2]
+            x2 = dets[:, 3]
+            y2 = dets[:, 4]
+            eta = 0 if self.normalized else 1
+            areas = (x2 - x1 + eta) * (y2 - y1 + eta)
+            dt_num = dets.shape[0]
+            order = np.array(range(dt_num))
+
+            keep = []
+            while order.size > 0:
+                i = order[0]
+                keep.append(i)
+                xx1 = np.maximum(x1[i], x1[order[1:]])
+                yy1 = np.maximum(y1[i], y1[order[1:]])
+                xx2 = np.minimum(x2[i], x2[order[1:]])
+                yy2 = np.minimum(y2[i], y2[order[1:]])
+
+                w = np.maximum(0.0, xx2 - xx1 + eta)
+                h = np.maximum(0.0, yy2 - yy1 + eta)
+                inter = w * h
+                ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+                diou_term = _calc_diou_term([x1[i], y1[i], x2[i], y2[i]], [
+                    x1[order[1:]], y1[order[1:]], x2[order[1:]], y2[order[1:]]
+                ])
+
+                inds = np.where(ovr - diou_term <= thres)[0]
+
+                order = order[inds + 1]
+
+            dets_final = dets[keep]
+            return dets_final
+
+        def _diou_nms(bboxes, scores):
+            bboxes = np.array(bboxes)
+            scores = np.array(scores)
+            class_nums = scores.shape[-1]
+
+            score_threshold = self.score_threshold
+            nms_threshold = self.nms_threshold
+            keep_top_k = self.keep_top_k
+
+            cls_boxes = [[] for _ in range(class_nums)]
+            cls_ids = [[] for _ in range(class_nums)]
+
+            start_idx = 1 if self.background_label == 0 else 0
+            for j in range(start_idx, class_nums):
+                inds = np.where(scores[:, j] >= score_threshold)[0]
+                scores_j = scores[inds, j]
+                rois_j = bboxes[inds, j, :]
+                dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
+                    np.float32, copy=False)
+                cls_rank = np.argsort(-dets_j[:, 0])
+                dets_j = dets_j[cls_rank]
+
+                cls_boxes[j] = _diou_nms_for_cls(dets_j, thres=nms_threshold)
+                cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
+                                                                           1)
+
+            cls_boxes = np.vstack(cls_boxes[start_idx:])
+            cls_ids = np.vstack(cls_ids[start_idx:])
+            pred_result = np.hstack([cls_ids, cls_boxes]).astype(np.float32)
+
+            # Limit to max_per_image detections **over all classes**
+            image_scores = cls_boxes[:, 0]
+            if len(image_scores) > keep_top_k:
+                image_thresh = np.sort(image_scores)[-keep_top_k]
+                keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
+                pred_result = pred_result[keep, :]
+
+            res = fluid.LoDTensor()
+            res.set_lod([[0, pred_result.shape[0]]])
+            if pred_result.shape[0] == 0:
+                pred_result = np.array([[1]], dtype=np.float32)
+            res.set(pred_result, fluid.CPUPlace())
+
+            return res
+
+        pred_result = create_tmp_var(
+            fluid.default_main_program(),
+            name='diou_nms_pred_result',
+            dtype='float32',
+            shape=[-1, 6],
+            lod_level=0)
+        fluid.layers.py_func(
+            func=_diou_nms, x=[bboxes, scores], out=pred_result)
+        return pred_result
+
+
+class LibraBBoxAssigner(object):
+    def __init__(self,
+                 batch_size_per_im=512,
+                 fg_fraction=.25,
+                 fg_thresh=.5,
+                 bg_thresh_hi=.5,
+                 bg_thresh_lo=0.,
+                 bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+                 num_classes=81,
+                 shuffle_before_sample=True,
+                 is_cls_agnostic=False,
+                 num_bins=3):
+        super(LibraBBoxAssigner, self).__init__()
+        self.batch_size_per_im = batch_size_per_im
+        self.fg_fraction = fg_fraction
+        self.fg_thresh = fg_thresh
+        self.bg_thresh_hi = bg_thresh_hi
+        self.bg_thresh_lo = bg_thresh_lo
+        self.bbox_reg_weights = bbox_reg_weights
+        self.class_nums = num_classes
+        self.use_random = shuffle_before_sample
+        self.is_cls_agnostic = is_cls_agnostic
+        self.num_bins = num_bins
+
+    def __call__(
+            self,
+            rpn_rois,
+            gt_classes,
+            is_crowd,
+            gt_boxes,
+            im_info, ):
+        return self.generate_proposal_label_libra(
+            rpn_rois=rpn_rois,
+            gt_classes=gt_classes,
+            is_crowd=is_crowd,
+            gt_boxes=gt_boxes,
+            im_info=im_info,
+            batch_size_per_im=self.batch_size_per_im,
+            fg_fraction=self.fg_fraction,
+            fg_thresh=self.fg_thresh,
+            bg_thresh_hi=self.bg_thresh_hi,
+            bg_thresh_lo=self.bg_thresh_lo,
+            bbox_reg_weights=self.bbox_reg_weights,
+            class_nums=self.class_nums,
+            use_random=self.use_random,
+            is_cls_agnostic=self.is_cls_agnostic,
+            is_cascade_rcnn=False)
+
+    def generate_proposal_label_libra(
+            self, rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
+            batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
+            bg_thresh_lo, bbox_reg_weights, class_nums, use_random,
+            is_cls_agnostic, is_cascade_rcnn):
+        num_bins = self.num_bins
+
+        def create_tmp_var(program, name, dtype, shape, lod_level=None):
+            return program.current_block().create_var(
+                name=name, dtype=dtype, shape=shape, lod_level=lod_level)
+
+        def _sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
+            if len(pos_inds) <= num_expected:
+                return pos_inds
+            else:
+                unique_gt_inds = np.unique(max_classes[pos_inds])
+                num_gts = len(unique_gt_inds)
+                num_per_gt = int(round(num_expected / float(num_gts)) + 1)
+
+                sampled_inds = []
+                for i in unique_gt_inds:
+                    inds = np.nonzero(max_classes == i)[0]
+                    before_len = len(inds)
+                    inds = list(set(inds) & set(pos_inds))
+                    after_len = len(inds)
+                    if len(inds) > num_per_gt:
+                        inds = np.random.choice(
+                            inds, size=num_per_gt, replace=False)
+                    sampled_inds.extend(list(inds))  # combine as a new sampler
+                if len(sampled_inds) < num_expected:
+                    num_extra = num_expected - len(sampled_inds)
+                    extra_inds = np.array(
+                        list(set(pos_inds) - set(sampled_inds)))
+                    assert len(sampled_inds)+len(extra_inds) == len(pos_inds), \
+                        "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
+                            len(sampled_inds), len(extra_inds), len(pos_inds))
+                    if len(extra_inds) > num_extra:
+                        extra_inds = np.random.choice(
+                            extra_inds, size=num_extra, replace=False)
+                    sampled_inds.extend(extra_inds.tolist())
+                elif len(sampled_inds) > num_expected:
+                    sampled_inds = np.random.choice(
+                        sampled_inds, size=num_expected, replace=False)
+                return sampled_inds
+
+        def sample_via_interval(max_overlaps, full_set, num_expected,
+                                floor_thr, num_bins, bg_thresh_hi):
+            max_iou = max_overlaps.max()
+            iou_interval = (max_iou - floor_thr) / num_bins
+            per_num_expected = int(num_expected / num_bins)
+
+            sampled_inds = []
+            for i in range(num_bins):
+                start_iou = floor_thr + i * iou_interval
+                end_iou = floor_thr + (i + 1) * iou_interval
+
+                tmp_set = set(
+                    np.where(
+                        np.logical_and(max_overlaps >= start_iou, max_overlaps
+                                       < end_iou))[0])
+                tmp_inds = list(tmp_set & full_set)
+
+                if len(tmp_inds) > per_num_expected:
+                    tmp_sampled_set = np.random.choice(
+                        tmp_inds, size=per_num_expected, replace=False)
+                else:
+                    tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
+                sampled_inds.append(tmp_sampled_set)
+
+            sampled_inds = np.concatenate(sampled_inds)
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(list(full_set - set(sampled_inds)))
+                assert len(sampled_inds)+len(extra_inds) == len(full_set), \
+                    "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
+                            len(sampled_inds), len(extra_inds), len(full_set))
+
+                if len(extra_inds) > num_extra:
+                    extra_inds = np.random.choice(
+                        extra_inds, num_extra, replace=False)
+                sampled_inds = np.concatenate([sampled_inds, extra_inds])
+
+            return sampled_inds
+
+        def _sample_neg(max_overlaps,
+                        max_classes,
+                        neg_inds,
+                        num_expected,
+                        floor_thr=-1,
+                        floor_fraction=0,
+                        num_bins=3,
+                        bg_thresh_hi=0.5):
+            if len(neg_inds) <= num_expected:
+                return neg_inds
+            else:
+                # balance sampling for negative samples
+                neg_set = set(neg_inds)
+                if floor_thr > 0:
+                    floor_set = set(
+                        np.where(
+                            np.logical_and(max_overlaps >= 0, max_overlaps <
+                                           floor_thr))[0])
+                    iou_sampling_set = set(
+                        np.where(max_overlaps >= floor_thr)[0])
+                elif floor_thr == 0:
+                    floor_set = set(np.where(max_overlaps == 0)[0])
+                    iou_sampling_set = set(
+                        np.where(max_overlaps > floor_thr)[0])
+                else:
+                    floor_set = set()
+                    iou_sampling_set = set(
+                        np.where(max_overlaps > floor_thr)[0])
+                    floor_thr = 0
+
+                floor_neg_inds = list(floor_set & neg_set)
+                iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
+
+                num_expected_iou_sampling = int(num_expected *
+                                                (1 - floor_fraction))
+                if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
+                    if num_bins >= 2:
+                        iou_sampled_inds = sample_via_interval(
+                            max_overlaps,
+                            set(iou_sampling_neg_inds),
+                            num_expected_iou_sampling, floor_thr, num_bins,
+                            bg_thresh_hi)
+                    else:
+                        iou_sampled_inds = np.random.choice(
+                            iou_sampling_neg_inds,
+                            size=num_expected_iou_sampling,
+                            replace=False)
+                else:
+                    iou_sampled_inds = np.array(
+                        iou_sampling_neg_inds, dtype=np.int)
+                num_expected_floor = num_expected - len(iou_sampled_inds)
+                if len(floor_neg_inds) > num_expected_floor:
+                    sampled_floor_inds = np.random.choice(
+                        floor_neg_inds, size=num_expected_floor, replace=False)
+                else:
+                    sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
+                sampled_inds = np.concatenate(
+                    (sampled_floor_inds, iou_sampled_inds))
+                if len(sampled_inds) < num_expected:
+                    num_extra = num_expected - len(sampled_inds)
+                    extra_inds = np.array(list(neg_set - set(sampled_inds)))
+                    if len(extra_inds) > num_extra:
+                        extra_inds = np.random.choice(
+                            extra_inds, size=num_extra, replace=False)
+                    sampled_inds = np.concatenate((sampled_inds, extra_inds))
+                return sampled_inds
+
+        def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
+                         batch_size_per_im, fg_fraction, fg_thresh,
+                         bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
+                         class_nums, use_random, is_cls_agnostic,
+                         is_cascade_rcnn):
+            rois_per_image = int(batch_size_per_im)
+            fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
+
+            # Roidb
+            im_scale = im_info[2]
+            inv_im_scale = 1. / im_scale
+            rpn_rois = rpn_rois * inv_im_scale
+            if is_cascade_rcnn:
+                rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]
+            boxes = np.vstack([gt_boxes, rpn_rois])
+            gt_overlaps = np.zeros((boxes.shape[0], class_nums))
+            box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
+            if len(gt_boxes) > 0:
+                proposal_to_gt_overlaps = bbox_overlaps(boxes, gt_boxes)
+
+                overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
+                overlaps_max = proposal_to_gt_overlaps.max(axis=1)
+                # Boxes with non-zero overlap with gt boxes
+                overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
+
+                overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
+                    overlapped_boxes_ind]]
+
+                for idx in range(len(overlapped_boxes_ind)):
+                    gt_overlaps[overlapped_boxes_ind[
+                        idx], overlapped_boxes_gt_classes[idx]] = overlaps_max[
+                            overlapped_boxes_ind[idx]]
+                    box_to_gt_ind_map[overlapped_boxes_ind[
+                        idx]] = overlaps_argmax[overlapped_boxes_ind[idx]]
+
+            crowd_ind = np.where(is_crowd)[0]
+            gt_overlaps[crowd_ind] = -1
+
+            max_overlaps = gt_overlaps.max(axis=1)
+            max_classes = gt_overlaps.argmax(axis=1)
+
+            # Cascade RCNN Decode Filter
+            if is_cascade_rcnn:
+                ws = boxes[:, 2] - boxes[:, 0] + 1
+                hs = boxes[:, 3] - boxes[:, 1] + 1
+                keep = np.where((ws > 0) & (hs > 0))[0]
+                boxes = boxes[keep]
+                max_overlaps = max_overlaps[keep]
+                fg_inds = np.where(max_overlaps >= fg_thresh)[0]
+                bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
+                    max_overlaps >= bg_thresh_lo))[0]
+                fg_rois_per_this_image = fg_inds.shape[0]
+                bg_rois_per_this_image = bg_inds.shape[0]
+            else:
+                # Foreground
+                fg_inds = np.where(max_overlaps >= fg_thresh)[0]
+                fg_rois_per_this_image = np.minimum(fg_rois_per_im,
+                                                    fg_inds.shape[0])
+                # Sample foreground if there are too many
+                if fg_inds.shape[0] > fg_rois_per_this_image:
+                    if use_random:
+                        fg_inds = _sample_pos(max_overlaps, max_classes,
+                                              fg_inds, fg_rois_per_this_image)
+                fg_inds = fg_inds[:fg_rois_per_this_image]
+
+                # Background
+                bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
+                    max_overlaps >= bg_thresh_lo))[0]
+                bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
+                bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
+                                                    bg_inds.shape[0])
+                assert bg_rois_per_this_image >= 0, "bg_rois_per_this_image must be >= 0 but got {}".format(
+                    bg_rois_per_this_image)
+
+                # Sample background if there are too many
+                if bg_inds.shape[0] > bg_rois_per_this_image:
+                    if use_random:
+                        # libra neg sample
+                        bg_inds = _sample_neg(
+                            max_overlaps,
+                            max_classes,
+                            bg_inds,
+                            bg_rois_per_this_image,
+                            num_bins=num_bins,
+                            bg_thresh_hi=bg_thresh_hi)
+                bg_inds = bg_inds[:bg_rois_per_this_image]
+
+            keep_inds = np.append(fg_inds, bg_inds)
+            sampled_labels = max_classes[keep_inds]  # N x 1
+            sampled_labels[fg_rois_per_this_image:] = 0
+            sampled_boxes = boxes[keep_inds]  # N x 324
+            sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
+            sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
+            bbox_label_targets = _compute_targets(
+                sampled_boxes, sampled_gts, sampled_labels, bbox_reg_weights)
+            bbox_targets, bbox_inside_weights = _expand_bbox_targets(
+                bbox_label_targets, class_nums, is_cls_agnostic)
+            bbox_outside_weights = np.array(
+                bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)
+            # Scale rois
+            sampled_rois = sampled_boxes * im_scale
+
+            # Faster RCNN blobs
+            frcn_blobs = dict(
+                rois=sampled_rois,
+                labels_int32=sampled_labels,
+                bbox_targets=bbox_targets,
+                bbox_inside_weights=bbox_inside_weights,
+                bbox_outside_weights=bbox_outside_weights)
+            return frcn_blobs
+
+        def _compute_targets(roi_boxes, gt_boxes, labels, bbox_reg_weights):
+            assert roi_boxes.shape[0] == gt_boxes.shape[0]
+            assert roi_boxes.shape[1] == 4
+            assert gt_boxes.shape[1] == 4
+
+            targets = np.zeros(roi_boxes.shape)
+            bbox_reg_weights = np.asarray(bbox_reg_weights)
+            targets = box_to_delta(
+                ex_boxes=roi_boxes,
+                gt_boxes=gt_boxes,
+                weights=bbox_reg_weights)
+
+            return np.hstack([labels[:, np.newaxis], targets]).astype(
+                np.float32, copy=False)
+
+        def _expand_bbox_targets(bbox_targets_input, class_nums,
+                                 is_cls_agnostic):
+            class_labels = bbox_targets_input[:, 0]
+            fg_inds = np.where(class_labels > 0)[0]
+            bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums
+                                     if not is_cls_agnostic else 4 * 2))
+            bbox_inside_weights = np.zeros(bbox_targets.shape)
+            for ind in fg_inds:
+                class_label = int(class_labels[
+                    ind]) if not is_cls_agnostic else 1
+                start_ind = class_label * 4
+                end_ind = class_label * 4 + 4
+                bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind,
+                                                                          1:]
+                bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0,
+                                                               1.0)
+            return bbox_targets, bbox_inside_weights
+
+        def generate_func(
+                rpn_rois,
+                gt_classes,
+                is_crowd,
+                gt_boxes,
+                im_info, ):
+            rpn_rois_lod = rpn_rois.lod()[0]
+            gt_classes_lod = gt_classes.lod()[0]
+
+            # convert
+            rpn_rois = np.array(rpn_rois)
+            gt_classes = np.array(gt_classes)
+            is_crowd = np.array(is_crowd)
+            gt_boxes = np.array(gt_boxes)
+            im_info = np.array(im_info)
+
+            rois = []
+            labels_int32 = []
+            bbox_targets = []
+            bbox_inside_weights = []
+            bbox_outside_weights = []
+            lod = [0]
+
+            for idx in range(len(rpn_rois_lod) - 1):
+                rois_si = rpn_rois_lod[idx]
+                rois_ei = rpn_rois_lod[idx + 1]
+
+                gt_si = gt_classes_lod[idx]
+                gt_ei = gt_classes_lod[idx + 1]
+                frcn_blobs = _sample_rois(
+                    rpn_rois[rois_si:rois_ei], gt_classes[gt_si:gt_ei],
+                    is_crowd[gt_si:gt_ei], gt_boxes[gt_si:gt_ei], im_info[idx],
+                    batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
+                    bg_thresh_lo, bbox_reg_weights, class_nums, use_random,
+                    is_cls_agnostic, is_cascade_rcnn)
+                lod.append(frcn_blobs['rois'].shape[0] + lod[-1])
+                rois.append(frcn_blobs['rois'])
+                labels_int32.append(frcn_blobs['labels_int32'].reshape(-1, 1))
+                bbox_targets.append(frcn_blobs['bbox_targets'])
+                bbox_inside_weights.append(frcn_blobs['bbox_inside_weights'])
+                bbox_outside_weights.append(frcn_blobs['bbox_outside_weights'])
+
+            rois = np.vstack(rois)
+            labels_int32 = np.vstack(labels_int32)
+            bbox_targets = np.vstack(bbox_targets)
+            bbox_inside_weights = np.vstack(bbox_inside_weights)
+            bbox_outside_weights = np.vstack(bbox_outside_weights)
+
+            # create lod-tensor for return
+            # notice that the func create_lod_tensor does not work well here
+            ret_rois = fluid.LoDTensor()
+            ret_rois.set_lod([lod])
+            ret_rois.set(rois.astype("float32"), fluid.CPUPlace())
+
+            ret_labels_int32 = fluid.LoDTensor()
+            ret_labels_int32.set_lod([lod])
+            ret_labels_int32.set(
+                labels_int32.astype("int32"), fluid.CPUPlace())
+
+            ret_bbox_targets = fluid.LoDTensor()
+            ret_bbox_targets.set_lod([lod])
+            ret_bbox_targets.set(
+                bbox_targets.astype("float32"), fluid.CPUPlace())
+
+            ret_bbox_inside_weights = fluid.LoDTensor()
+            ret_bbox_inside_weights.set_lod([lod])
+            ret_bbox_inside_weights.set(
+                bbox_inside_weights.astype("float32"), fluid.CPUPlace())
+
+            ret_bbox_outside_weights = fluid.LoDTensor()
+            ret_bbox_outside_weights.set_lod([lod])
+            ret_bbox_outside_weights.set(
+                bbox_outside_weights.astype("float32"), fluid.CPUPlace())
+
+            return ret_rois, ret_labels_int32, ret_bbox_targets, ret_bbox_inside_weights, ret_bbox_outside_weights
+
+        rois = create_tmp_var(
+            fluid.default_main_program(),
+            name=None,  #'rois',
+            dtype='float32',
+            shape=[-1, 4], )
+        bbox_inside_weights = create_tmp_var(
+            fluid.default_main_program(),
+            name=None,  #'bbox_inside_weights',
+            dtype='float32',
+            shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
+        bbox_outside_weights = create_tmp_var(
+            fluid.default_main_program(),
+            name=None,  #'bbox_outside_weights',
+            dtype='float32',
+            shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
+        bbox_targets = create_tmp_var(
+            fluid.default_main_program(),
+            name=None,  #'bbox_targets',
+            dtype='float32',
+            shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
+        labels_int32 = create_tmp_var(
+            fluid.default_main_program(),
+            name=None,  #'labels_int32',
+            dtype='int32',
+            shape=[-1, 1], )
+
+        outs = [
+            rois, labels_int32, bbox_targets, bbox_inside_weights,
+            bbox_outside_weights
+        ]
+
+        fluid.layers.py_func(
+            func=generate_func,
+            x=[rpn_rois, gt_classes, is_crowd, gt_boxes, im_info],
+            out=outs)
+        return outs
+
+
+class BBoxAssigner(object):
+    def __init__(self,
+                 batch_size_per_im=512,
+                 fg_fraction=.25,
+                 fg_thresh=.5,
+                 bg_thresh_hi=.5,
+                 bg_thresh_lo=0.,
+                 bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
+                 num_classes=81,
+                 shuffle_before_sample=True):
+        super(BBoxAssigner, self).__init__()
+        self.batch_size_per_im = batch_size_per_im
+        self.fg_fraction = fg_fraction
+        self.fg_thresh = fg_thresh
+        self.bg_thresh_hi = bg_thresh_hi
+        self.bg_thresh_lo = bg_thresh_lo
+        self.bbox_reg_weights = bbox_reg_weights
+        self.class_nums = num_classes
+        self.use_random = shuffle_before_sample
+
+    def __call__(self, rpn_rois, gt_classes, is_crowd, gt_boxes, im_info):
+        return fluid.layers.generate_proposal_labels(
+            rpn_rois=rpn_rois,
+            gt_classes=gt_classes,
+            is_crowd=is_crowd,
+            gt_boxes=gt_boxes,
+            im_info=im_info,
+            batch_size_per_im=self.batch_size_per_im,
+            fg_fraction=self.fg_fraction,
+            fg_thresh=self.fg_thresh,
+            bg_thresh_hi=self.bg_thresh_hi,
+            bg_thresh_lo=self.bg_thresh_lo,
+            bbox_reg_weights=self.bbox_reg_weights,
+            class_nums=self.class_nums,
+            use_random=self.use_random)
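
bbox_overlaps at the top of this file is plain NumPy and easy to sanity-check in isolation; a small self-contained sketch that re-implements the same pairwise IoU math (the +1 pixel convention matches the function above) rather than importing the module:

    import numpy as np

    def iou_matrix(a, b):
        # a: (N, 4), b: (M, 4), boxes as x1, y1, x2, y2
        area_a = (a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1)
        area_b = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)
        xx1 = np.maximum(a[:, None, 0], b[None, :, 0])
        yy1 = np.maximum(a[:, None, 1], b[None, :, 1])
        xx2 = np.minimum(a[:, None, 2], b[None, :, 2])
        yy2 = np.minimum(a[:, None, 3], b[None, :, 3])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        return inter / (area_a[:, None] + area_b[None, :] - inter)

    boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
    print(iou_matrix(boxes, boxes))  # ~1.0 on the diagonal, ~0.7 for the first pair, 0 for the last

MultiClassDiouNMS then subtracts a center-distance penalty (the _calc_diou_term helper: squared center distance over the squared diagonal of the enclosing box) from this IoU before applying the threshold.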

+ 23 - 7
paddlex/cv/nets/detection/rpn_head.py

@@ -312,13 +312,29 @@ class RPNHead(object):
                         negative_overlap=self.rpn_negative_overlap,
                         num_classes=1)
                 fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num')
-                score_tgt = fluid.layers.cast(score_tgt, 'int32')
-                rpn_cls_loss = fluid.layers.sigmoid_focal_loss(
-                    x=score_pred,
-                    label=score_tgt,
-                    fg_num=fg_num,
-                    gamma=self.rpn_focal_loss_gamma,
-                    alpha=self.rpn_focal_loss_alpha)
+                #score_tgt = fluid.layers.cast(score_tgt, 'int32')
+                #rpn_cls_loss = fluid.layers.sigmoid_focal_loss(
+                #    x=score_pred,
+                #    label=score_tgt,
+                #    fg_num=fg_num,
+                #    gamma=self.rpn_focal_loss_gamma,
+                #    alpha=self.rpn_focal_loss_alpha)
+                score_tgt = fluid.layers.cast(x=score_tgt, dtype='float32')
+                score_tgt.stop_gradient = True
+                loss = fluid.layers.sigmoid_cross_entropy_with_logits(
+                    x=score_pred, label=score_tgt)
+
+                pred = fluid.layers.sigmoid(score_pred)
+                p_t = pred * score_tgt + (1 - pred) * (1 - score_tgt)
+
+                if self.rpn_focal_loss_alpha is not None:
+                    alpha_t = self.rpn_focal_loss_alpha * score_tgt + (
+                        1 - self.rpn_focal_loss_alpha) * (1 - score_tgt)
+                    loss = alpha_t * loss
+                gamma_t = fluid.layers.pow((1 - p_t),
+                                           self.rpn_focal_loss_gamma)
+                loss = gamma_t * loss
+                rpn_cls_loss = loss / fg_num
         else:
             score_pred, loc_pred, score_tgt, loc_tgt, bbox_weight = \
                 fluid.layers.rpn_target_assign(
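
The replacement block computes sigmoid focal loss by hand instead of calling fluid.layers.sigmoid_focal_loss: per-anchor cross entropy, scaled by alpha_t and (1 - p_t)^gamma, then divided by the foreground count. A small NumPy check of that formula (illustrative only, not the fluid graph code):

    import numpy as np

    def sigmoid_focal_loss(logits, labels, fg_num, alpha=0.25, gamma=2.0):
        # labels are 0/1 per anchor; fg_num is the number of positive anchors
        p = 1.0 / (1.0 + np.exp(-logits))
        ce = -(labels * np.log(p) + (1 - labels) * np.log(1 - p))
        p_t = p * labels + (1 - p) * (1 - labels)
        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
        return alpha_t * (1 - p_t) ** gamma * ce / fg_num

    logits = np.array([2.0, -1.0, 0.5])
    labels = np.array([1.0, 0.0, 1.0])
    print(sigmoid_focal_loss(logits, labels, fg_num=2))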

+ 9 - 4
paddlex/cv/transforms/det_transforms.py

@@ -177,7 +177,7 @@ class ResizeByShort(DetTransform):
     4. Resize the image according to the computed scale.
 
     Args:
-        target_size (int): Target length of the shorter side. Default is 800.
+        short_size (int|list): Target length of the shorter side. Default is 800.
         max_size (int): Upper limit on the longer side. Default is 1333.
 
      Raises:
@@ -186,9 +186,9 @@ class ResizeByShort(DetTransform):
 
     def __init__(self, short_size=800, max_size=1333):
         self.max_size = int(max_size)
-        if not isinstance(short_size, int):
+        if not (isinstance(short_size, int) or isinstance(short_size, list)):
             raise TypeError(
-                "Type of short_size is invalid. Must be Integer, now is {}".
+                "Type of short_size is invalid. Must be Integer or List, now is {}".
                 format(type(short_size)))
         self.short_size = short_size
         if not (isinstance(self.max_size, int)):
@@ -221,7 +221,12 @@ class ResizeByShort(DetTransform):
             raise ValueError('ResizeByShort: image is not 3-dimensional.')
         im_short_size = min(im.shape[0], im.shape[1])
         im_long_size = max(im.shape[0], im.shape[1])
-        scale = float(self.short_size) / im_short_size
+        if isinstance(self.short_size, list):
+            # Case for multi-scale training
+            selected_size = random.choice(self.short_size)
+        else:
+            selected_size = self.short_size
+        scale = float(selected_size) / im_short_size
         if self.max_size > 0 and np.round(scale *
                                           im_long_size) > self.max_size:
             scale = float(self.max_size) / float(im_long_size)
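
With short_size now accepting a list, a target scale is drawn per sample via random.choice, which is the usual multi-scale training setup. A usage sketch, assuming the standard paddlex.det.transforms pipeline around it (only ResizeByShort itself comes from this diff):

    from paddlex.det import transforms

    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(),
        # shorter side resized to a randomly chosen target, longer side capped at max_size
        transforms.ResizeByShort(
            short_size=[640, 672, 704, 736, 768, 800], max_size=1333),
        transforms.Padding(coarsest_stride=32),
    ])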