Merge pull request #804 from will-jl944/develop_jf

Refine Detection Code
FlyingQianMM 4 years ago
parent
commit
4e3094afa4

+ 1 - 1
dygraph/PaddleDetection

@@ -1 +1 @@
-Subproject commit 66d7eefab9aca8243ddf49a52b748b786b80ffb5
+Subproject commit c987dc1e543f1e489a32b165d7078b591d2ca363

+ 7 - 0
dygraph/paddlex/cv/datasets/coco.py

@@ -57,12 +57,14 @@ class CocoDetection(VOCDetection):
         super(VOCDetection, self).__init__()
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
+        self.num_max_boxes = 50
         self.use_mix = False
         if self.transforms is not None:
             for op in self.transforms.transforms:
                 if isinstance(op, MixupImage):
                     self.mixup_op = copy.deepcopy(op)
                     self.use_mix = True
+                    self.num_max_boxes *= 2
                     break
 
         self.batch_transforms = None
@@ -153,6 +155,11 @@ class CocoDetection(VOCDetection):
                 **
                 label_info
             }))
+        if self.use_mix:
+            self.num_max_boxes = max(self.num_max_boxes, 2 * len(instances))
+        else:
+            self.num_max_boxes = max(self.num_max_boxes, len(instances))
+
         if not len(self.file_list) > 0:
             raise Exception('not found any coco record in %s' % ann_file)
         logging.info("{} samples in file {}".format(

+ 7 - 0
dygraph/paddlex/cv/datasets/voc.py

@@ -56,6 +56,7 @@ class VOCDetection(Dataset):
         super(VOCDetection, self).__init__()
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
+        self.num_max_boxes = 50
 
         self.use_mix = False
         if self.transforms is not None:
@@ -63,6 +64,7 @@ class VOCDetection(Dataset):
                 if isinstance(op, MixupImage):
                     self.mixup_op = copy.deepcopy(op)
                     self.use_mix = True
+                    self.num_max_boxes *= 2
                     break
 
         self.batch_transforms = None
@@ -257,6 +259,11 @@ class VOCDetection(Dataset):
                         'id': int(im_id[0]),
                         'file_name': osp.split(img_file)[1]
                     })
+                if self.use_mix:
+                    self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
+                else:
+                    self.num_max_boxes = max(self.num_max_boxes, len(objs))
+
         if not len(self.file_list) > 0:
             raise Exception('not found any voc record in %s' % (file_list))
         logging.info("{} samples in file {}".format(

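Note on the two dataset hunks above: `num_max_boxes` is now derived from the annotations themselves (starting from a floor of 50, doubled when `MixupImage` is in the pipeline, since mixup concatenates the boxes of two samples), replacing the fixed per-model constants dropped from `detector.py` below. A minimal sketch of the bookkeeping, where `annotations` is a hypothetical stand-in for the parsed VOC/COCO records:

```python
import copy

class DetDatasetSketch:
    """Illustrative only -- mirrors the num_max_boxes logic added to
    VOCDetection and CocoDetection, not the real classes."""

    def __init__(self, annotations, transforms=None):
        self.transforms = copy.deepcopy(transforms)
        self.num_max_boxes = 50              # default floor
        self.use_mix = False
        if self.transforms is not None:
            for op in self.transforms.transforms:
                if type(op).__name__ == 'MixupImage':
                    self.use_mix = True
                    self.num_max_boxes *= 2  # mixup merges two samples' boxes
                    break
        for objs in annotations:             # one list of gt objects per image
            factor = 2 if self.use_mix else 1
            self.num_max_boxes = max(self.num_max_boxes, factor * len(objs))
```

`BaseDetector.train()` then reads `train_dataset.num_max_boxes`, so `_PadBox` pads every sample to a box count guaranteed to fit the dataset.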
+ 48 - 59
dygraph/paddlex/cv/models/detector.py

@@ -192,9 +192,10 @@ class BaseDetector(BaseModel):
                 "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
             self.metric = metric.lower()
 
+        self.labels = train_dataset.labels
+        self.num_max_boxes = train_dataset.num_max_boxes
         train_dataset.batch_transforms = self._compose_batch_transform(
             train_dataset.transforms, mode='train')
-        self.labels = train_dataset.labels
 
         # build optimizer if not defined
         if optimizer is None:
@@ -334,12 +335,24 @@ class BaseDetector(BaseModel):
             collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
 
         """
-        if eval_dataset.__class__.__name__ == 'VOCDetection':
+
+        if metric is None:
+            if not hasattr(self, 'metric'):
+                if eval_dataset.__class__.__name__ == 'VOCDetection':
+                    self.metric = 'voc'
+                elif eval_dataset.__class__.__name__ == 'CocoDetection':
+                    self.metric = 'coco'
+        else:
+            assert metric.lower() in ['coco', 'voc'], \
+                "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
+            self.metric = metric.lower()
+
+        if self.metric == 'voc':
             eval_dataset.data_fields = {
                 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                 'difficult'
             }
-        elif eval_dataset.__class__.__name__ == 'CocoDetection':
+        elif self.metric == 'coco':
             if self.__class__.__name__ == 'MaskRCNN':
                 eval_dataset.data_fields = {
                     'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
@@ -380,41 +393,16 @@ class BaseDetector(BaseModel):
                 is_bbox_normalized = any(
                     isinstance(t, _NormalizeBox)
                     for t in eval_dataset.batch_transforms.batch_transforms)
-            if metric is None:
-                if getattr(self, 'metric', None) is not None:
-                    if self.metric == 'voc':
-                        eval_metric = VOCMetric(
-                            labels=eval_dataset.labels,
-                            coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                            is_bbox_normalized=is_bbox_normalized,
-                            classwise=False)
-                    else:
-                        eval_metric = COCOMetric(
-                            coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                            classwise=False)
-                else:
-                    if eval_dataset.__class__.__name__ == 'VOCDetection':
-                        eval_metric = VOCMetric(
-                            labels=eval_dataset.labels,
-                            coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                            is_bbox_normalized=is_bbox_normalized,
-                            classwise=False)
-                    elif eval_dataset.__class__.__name__ == 'CocoDetection':
-                        eval_metric = COCOMetric(
-                            coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                            classwise=False)
+            if self.metric == 'voc':
+                eval_metric = VOCMetric(
+                    labels=eval_dataset.labels,
+                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
+                    is_bbox_normalized=is_bbox_normalized,
+                    classwise=False)
             else:
-                assert metric.lower() in ['coco', 'voc'], \
-                    "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
-                if metric.lower() == 'coco':
-                    eval_metric = COCOMetric(
-                        coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                        classwise=False)
-                else:
-                    eval_metric = VOCMetric(
-                        labels=eval_dataset.labels,
-                        is_bbox_normalized=is_bbox_normalized,
-                        classwise=False)
+                eval_metric = COCOMetric(
+                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
+                    classwise=False)
             scores = collections.OrderedDict()
             logging.info(
                 "Start to evaluate(total_samples={}, total_steps={})...".
@@ -649,8 +637,7 @@ class YOLOv3(BaseDetector):
     def _compose_batch_transform(self, transforms, mode='train'):
         if mode == 'train':
             default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=-1, pad_gt=False), _NormalizeBox(),
+                _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
                 _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
                 _Gt2YoloTarget(
                     anchor_masks=self.anchor_masks,
@@ -660,10 +647,11 @@ class YOLOv3(BaseDetector):
                     num_classes=self.num_classes)
             ]
         else:
-            default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=-1, pad_gt=False)
-            ]
+            default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
+        if mode == 'eval' and self.metric == 'voc':
+            collate_batch = False
+        else:
+            collate_batch = True
 
         custom_batch_transforms = []
         for i, op in enumerate(transforms.transforms):
@@ -675,8 +663,9 @@ class YOLOv3(BaseDetector):
                         "Please check the {} transforms.".format(mode))
                 custom_batch_transforms.insert(0, copy.deepcopy(op))
 
-        batch_transforms = BatchCompose(custom_batch_transforms +
-                                        default_batch_transforms)
+        batch_transforms = BatchCompose(
+            custom_batch_transforms + default_batch_transforms,
+            collate_batch=collate_batch)
 
         return batch_transforms
 
@@ -901,14 +890,14 @@ class FasterRCNN(BaseDetector):
     def _compose_batch_transform(self, transforms, mode='train'):
         if mode == 'train':
             default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
+                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
             ]
+            collate_batch = False
         else:
             default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
+                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
             ]
+            collate_batch = True
         custom_batch_transforms = []
         for i, op in enumerate(transforms.transforms):
             if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
@@ -919,8 +908,9 @@ class FasterRCNN(BaseDetector):
                         "Please check the {} transforms.".format(mode))
                 custom_batch_transforms.insert(0, copy.deepcopy(op))
 
-        batch_transforms = BatchCompose(custom_batch_transforms +
-                                        default_batch_transforms)
+        batch_transforms = BatchCompose(
+            custom_batch_transforms + default_batch_transforms,
+            collate_batch=collate_batch)
 
         return batch_transforms
 
@@ -1189,7 +1179,6 @@ class PPYOLOTiny(YOLOv3):
         self.anchors = anchors
         self.anchor_masks = anchor_masks
         self.downsample_ratios = downsample_ratios
-        self.num_max_boxes = 100
         self.model_name = 'PPYOLOTiny'
 
 
@@ -1313,7 +1302,6 @@ class PPYOLOv2(YOLOv3):
         self.anchors = anchors
         self.anchor_masks = anchor_masks
         self.downsample_ratios = downsample_ratios
-        self.num_max_boxes = 100
         self.model_name = 'PPYOLOv2'
 
 
@@ -1542,14 +1530,14 @@ class MaskRCNN(BaseDetector):
     def _compose_batch_transform(self, transforms, mode='train'):
         if mode == 'train':
             default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
+                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
             ]
+            collate_batch = False
         else:
             default_batch_transforms = [
-                _BatchPadding(
-                    pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
+                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
             ]
+            collate_batch = True
         custom_batch_transforms = []
         for i, op in enumerate(transforms.transforms):
             if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
@@ -1560,7 +1548,8 @@ class MaskRCNN(BaseDetector):
                         "Please check the {} transforms.".format(mode))
                 custom_batch_transforms.insert(0, copy.deepcopy(op))
 
-        batch_transforms = BatchCompose(custom_batch_transforms +
-                                        default_batch_transforms)
+        batch_transforms = BatchCompose(
+            custom_batch_transforms + default_batch_transforms,
+            collate_batch=collate_batch)
 
         return batch_transforms

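Across the three `_compose_batch_transform` implementations above, the `pad_gt` flag that `_BatchPadding` used to take is replaced by a `collate_batch` decision handed to `BatchCompose`. The rule they encode, condensed into a hypothetical helper for illustration:

```python
def pick_collate_batch(model_name, mode, metric=None):
    """Illustrative summary of the collate_batch choices in detector.py."""
    if model_name in ('YOLOv3', 'PPYOLO', 'PPYOLOTiny', 'PPYOLOv2'):
        # YOLO gt is padded to num_max_boxes, so dense collation is safe;
        # VOC evaluation needs per-sample gt lists, hence the exception.
        return not (mode == 'eval' and metric == 'voc')
    if model_name in ('FasterRCNN', 'MaskRCNN'):
        # RCNN training consumes variable-length gt as lists of tensors.
        return mode != 'train'
    raise ValueError('unknown model: {}'.format(model_name))

assert pick_collate_batch('MaskRCNN', 'train') is False
assert pick_collate_batch('YOLOv3', 'eval', metric='voc') is False
assert pick_collate_batch('YOLOv3', 'train') is True
```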
+ 3 - 4
dygraph/paddlex/cv/models/utils/det_metrics/coco_utils.py

@@ -20,9 +20,8 @@ import sys
 import copy
 import numpy as np
 import itertools
-
+from ppdet.metrics.map_utils import draw_pr_curve
 from .json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
-from .map_utils import _draw_pr_curve
 
 import paddlex.utils.logging as logging
 
@@ -123,7 +122,7 @@ def cocoapi_eval(anns,
                 (str(nm["name"]), '{:0.3f}'.format(float(ap))))
             pr_array = precisions[0, :, idx, 0, 2]
             recall_array = np.arange(0.0, 1.01, 0.01)
-            _draw_pr_curve(
+            draw_pr_curve(
                 pr_array,
                 recall_array,
                 out_dir=style + '_pr_curve',
@@ -133,7 +132,7 @@ def cocoapi_eval(anns,
         results_flatten = list(itertools.chain(*results_per_category))
         headers = ['category', 'AP'] * (num_columns // 2)
         results_2d = itertools.zip_longest(
-            * [results_flatten[i::num_columns] for i in range(num_columns)])
+            *[results_flatten[i::num_columns] for i in range(num_columns)])
         table_data = [headers]
         table_data += [result for result in results_2d]
         table = AsciiTable(table_data)

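The PR-curve helper now comes from PaddleDetection rather than the local copy deleted below. Judging from the unchanged call site, `ppdet.metrics.map_utils.draw_pr_curve` keeps the old `_draw_pr_curve` signature; a toy usage sketch under that assumption:

```python
import numpy as np
from ppdet.metrics.map_utils import draw_pr_curve  # needs paddledet installed

# Toy monotone PR data; in cocoapi_eval the values come from COCOeval.
recall = np.arange(0.0, 1.01, 0.01)
precision = np.linspace(1.0, 0.2, recall.size)
draw_pr_curve(precision, recall, iou=0.5, out_dir='demo_pr_curve',
              file_name='toy_precision_recall_curve.jpg')
```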
+ 0 - 305
dygraph/paddlex/cv/models/utils/det_metrics/map_utils.py

@@ -1,305 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import os
-import sys
-import numpy as np
-import itertools
-import paddlex.utils.logging as logging
-
-__all__ = [
-    '_draw_pr_curve', 'bbox_area', 'jaccard_overlap', 'prune_zero_padding',
-    'DetectionMAP'
-]
-
-
-def _draw_pr_curve(precision,
-                   recall,
-                   iou=0.5,
-                   out_dir='pr_curve',
-                   file_name='precision_recall_curve.jpg'):
-    if not os.path.exists(out_dir):
-        os.makedirs(out_dir)
-    output_path = os.path.join(out_dir, file_name)
-    try:
-        import matplotlib.pyplot as plt
-    except Exception as e:
-        logging.error('Matplotlib not found, plaese install matplotlib.'
-                      'for example: `pip install matplotlib`.')
-        raise e
-    plt.cla()
-    plt.figure('P-R Curve')
-    plt.title('Precision/Recall Curve(IoU={})'.format(iou))
-    plt.xlabel('Recall')
-    plt.ylabel('Precision')
-    plt.grid(True)
-    plt.plot(recall, precision)
-    plt.savefig(output_path)
-
-
-def bbox_area(bbox, is_bbox_normalized):
-    """
-    Calculate area of a bounding box
-    """
-    norm = 1. - float(is_bbox_normalized)
-    width = bbox[2] - bbox[0] + norm
-    height = bbox[3] - bbox[1] + norm
-    return width * height
-
-
-def jaccard_overlap(pred, gt, is_bbox_normalized=False):
-    """
-    Calculate jaccard overlap ratio between two bounding box
-    """
-    if pred[0] >= gt[2] or pred[2] <= gt[0] or \
-        pred[1] >= gt[3] or pred[3] <= gt[1]:
-        return 0.
-    inter_xmin = max(pred[0], gt[0])
-    inter_ymin = max(pred[1], gt[1])
-    inter_xmax = min(pred[2], gt[2])
-    inter_ymax = min(pred[3], gt[3])
-    inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax],
-                           is_bbox_normalized)
-    pred_size = bbox_area(pred, is_bbox_normalized)
-    gt_size = bbox_area(gt, is_bbox_normalized)
-    overlap = float(inter_size) / (pred_size + gt_size - inter_size)
-    return overlap
-
-
-def prune_zero_padding(gt_box, gt_label, difficult=None):
-    valid_cnt = 0
-    for i in range(len(gt_box)):
-        if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
-                gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
-            break
-        valid_cnt += 1
-    return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt]
-            if difficult is not None else None)
-
-
-class DetectionMAP(object):
-    """
-    Calculate detection mean average precision.
-    Currently support two types: 11point and integral
-
-    Args:
-        class_num (int): The class number.
-        overlap_thresh (float): The threshold of overlap
-            ratio between prediction bounding box and
-            ground truth bounding box for deciding
-            true/false positive. Default 0.5.
-        map_type (str): Calculation method of mean average
-            precision, currently support '11point' and
-            'integral'. Default '11point'.
-        is_bbox_normalized (bool): Whether bounding boxes
-            is normalized to range[0, 1]. Default False.
-        evaluate_difficult (bool): Whether to evaluate
-            difficult bounding boxes. Default False.
-        catid2name (dict): Mapping between category id and category name.
-        classwise (bool): Whether per-category AP and draw
-            P-R Curve or not.
-    """
-
-    def __init__(self,
-                 class_num,
-                 overlap_thresh=0.5,
-                 map_type='11point',
-                 is_bbox_normalized=False,
-                 evaluate_difficult=False,
-                 catid2name=None,
-                 classwise=False):
-        self.class_num = class_num
-        self.overlap_thresh = overlap_thresh
-        assert map_type in ['11point', 'integral'], \
-                "map_type currently only support '11point' "\
-                "and 'integral'"
-        self.map_type = map_type
-        self.is_bbox_normalized = is_bbox_normalized
-        self.evaluate_difficult = evaluate_difficult
-        self.classwise = classwise
-        self.classes = []
-        for cname in catid2name.values():
-            self.classes.append(cname)
-        self.reset()
-
-    def update(self, bbox, score, label, gt_box, gt_label, difficult=None):
-        """
-        Update metric statics from given prediction and ground
-        truth infomations.
-        """
-        if difficult is None:
-            difficult = np.zeros_like(gt_label)
-
-        # record class gt count
-        for gtl, diff in zip(gt_label, difficult):
-            if self.evaluate_difficult or int(diff) == 0:
-                self.class_gt_counts[int(np.array(gtl))] += 1
-
-        # record class score positive
-        visited = [False] * len(gt_label)
-        for b, s, l in zip(bbox, score, label):
-            xmin, ymin, xmax, ymax = b.tolist()
-            pred = [xmin, ymin, xmax, ymax]
-            max_idx = -1
-            max_overlap = -1.0
-            for i, gl in enumerate(gt_label):
-                if int(gl) == int(l):
-                    overlap = jaccard_overlap(pred, gt_box[i],
-                                              self.is_bbox_normalized)
-                    if overlap > max_overlap:
-                        max_overlap = overlap
-                        max_idx = i
-
-            if max_overlap > self.overlap_thresh:
-                if self.evaluate_difficult or \
-                        int(np.array(difficult[max_idx])) == 0:
-                    if not visited[max_idx]:
-                        self.class_score_poss[int(l)].append([s, 1.0])
-                        visited[max_idx] = True
-                    else:
-                        self.class_score_poss[int(l)].append([s, 0.0])
-            else:
-                self.class_score_poss[int(l)].append([s, 0.0])
-
-    def reset(self):
-        """
-        Reset metric statics
-        """
-        self.class_score_poss = [[] for _ in range(self.class_num)]
-        self.class_gt_counts = [0] * self.class_num
-        self.mAP = None
-
-    def accumulate(self):
-        """
-        Accumulate metric results and calculate mAP
-        """
-        mAP = 0.
-        valid_cnt = 0
-        eval_results = []
-        for score_pos, count in zip(self.class_score_poss,
-                                    self.class_gt_counts):
-            if count == 0: continue
-            if len(score_pos) == 0:
-                valid_cnt += 1
-                continue
-
-            accum_tp_list, accum_fp_list = \
-                    self._get_tp_fp_accum(score_pos)
-            precision = []
-            recall = []
-            for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
-                precision.append(float(ac_tp) / (ac_tp + ac_fp))
-                recall.append(float(ac_tp) / count)
-
-            one_class_ap = 0.0
-            if self.map_type == '11point':
-                max_precisions = [0.] * 11
-                start_idx = len(precision) - 1
-                for j in range(10, -1, -1):
-                    for i in range(start_idx, -1, -1):
-                        if recall[i] < float(j) / 10.:
-                            start_idx = i
-                            if j > 0:
-                                max_precisions[j - 1] = max_precisions[j]
-                                break
-                        else:
-                            if max_precisions[j] < precision[i]:
-                                max_precisions[j] = precision[i]
-                one_class_ap = sum(max_precisions) / 11.
-                mAP += one_class_ap
-                valid_cnt += 1
-            elif self.map_type == 'integral':
-                import math
-                prev_recall = 0.
-                for i in range(len(precision)):
-                    recall_gap = math.fabs(recall[i] - prev_recall)
-                    if recall_gap > 1e-6:
-                        one_class_ap += precision[i] * recall_gap
-                        prev_recall = recall[i]
-                mAP += one_class_ap
-                valid_cnt += 1
-            else:
-                logging.error("Unspported mAP type {}".format(self.map_type))
-                sys.exit(1)
-            eval_results.append({
-                'class': self.classes[valid_cnt - 1],
-                'ap': one_class_ap,
-                'precision': precision,
-                'recall': recall,
-            })
-        self.eval_results = eval_results
-        self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
-
-    def get_map(self):
-        """
-        Get mAP result
-        """
-        if self.mAP is None:
-            logging.error("mAP is not calculated.")
-        if self.classwise:
-            # Compute per-category AP and PR curve
-            try:
-                from terminaltables import AsciiTable
-            except Exception as e:
-                logging.error(
-                    'terminaltables not found, plaese install terminaltables. '
-                    'for example: `pip install terminaltables`.')
-                raise e
-            results_per_category = []
-            for eval_result in self.eval_results:
-                results_per_category.append(
-                    (str(eval_result['class']),
-                     '{:0.3f}'.format(float(eval_result['ap']))))
-                _draw_pr_curve(
-                    eval_result['precision'],
-                    eval_result['recall'],
-                    out_dir='voc_pr_curve',
-                    file_name='{}_precision_recall_curve.jpg'.format(
-                        eval_result['class']))
-
-            num_columns = min(6, len(results_per_category) * 2)
-            results_flatten = list(itertools.chain(*results_per_category))
-            headers = ['category', 'AP'] * (num_columns // 2)
-            results_2d = itertools.zip_longest(* [
-                results_flatten[i::num_columns] for i in range(num_columns)
-            ])
-            table_data = [headers]
-            table_data += [result for result in results_2d]
-            table = AsciiTable(table_data)
-            logging.info('Per-category of VOC AP: \n{}'.format(table.table))
-            logging.info(
-                "per-category PR curve has output to voc_pr_curve folder.")
-        return self.mAP
-
-    def _get_tp_fp_accum(self, score_pos_list):
-        """
-        Calculate accumulating true/false positive results from
-        [score, pos] records
-        """
-        sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
-        accum_tp = 0
-        accum_fp = 0
-        accum_tp_list = []
-        accum_fp_list = []
-        for (score, pos) in sorted_list:
-            accum_tp += int(pos)
-            accum_tp_list.append(accum_tp)
-            accum_fp += 1 - int(pos)
-            accum_fp_list.append(accum_fp)
-        return accum_tp_list, accum_fp_list

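The 305 deleted lines duplicated utilities that ship with PaddleDetection; `metrics.py` below switches to `ppdet.metrics.map_utils` for `prune_zero_padding` and `DetectionMAP`. For reference, the heart of the removed `'11point'` mAP mode, reduced to a self-contained function:

```python
import numpy as np

def eleven_point_ap(precision, recall):
    """11-point interpolated AP, as computed by the deleted DetectionMAP's
    '11point' map_type (now provided by ppdet.metrics.map_utils)."""
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    ap = 0.0
    for t in np.linspace(0.0, 1.0, 11):  # recall thresholds 0.0, 0.1, ..., 1.0
        mask = recall >= t
        ap += (precision[mask].max() if mask.any() else 0.0) / 11.0
    return ap

# A detector that is perfect up to recall 0.5 and finds nothing beyond it:
assert abs(eleven_point_ap([1.0, 1.0], [0.25, 0.5]) - 6 / 11) < 1e-9
```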
+ 10 - 10
dygraph/paddlex/cv/models/utils/det_metrics/metrics.py

@@ -21,8 +21,7 @@ import sys
 from collections import OrderedDict
 import paddle
 import numpy as np
-
-from .map_utils import prune_zero_padding, DetectionMAP
+from ppdet.metrics.map_utils import prune_zero_padding, DetectionMAP
 from .coco_utils import get_infer_results, cocoapi_eval
 import paddlex.utils.logging as logging
 
@@ -88,22 +87,23 @@ class VOCMetric(Metric):
 
         if bboxes.shape == (1, 1) or bboxes is None:
             return
-        gt_boxes = inputs['gt_bbox'].numpy()
-        gt_labels = inputs['gt_class'].numpy()
-        difficults = inputs['difficult'].numpy(
-        ) if not self.evaluate_difficult else None
+        gt_boxes = inputs['gt_bbox']
+        gt_labels = inputs['gt_class']
+        difficults = inputs['difficult'] if not self.evaluate_difficult \
+            else None
 
         scale_factor = inputs['scale_factor'].numpy(
         ) if 'scale_factor' in inputs else np.ones(
             (gt_boxes.shape[0], 2)).astype('float32')
 
         bbox_idx = 0
-        for i in range(gt_boxes.shape[0]):
-            gt_box = gt_boxes[i]
+        for i in range(len(gt_boxes)):
+            gt_box = gt_boxes[i].numpy()
             h, w = scale_factor[i]
             gt_box = gt_box / np.array([w, h, w, h])
-            gt_label = gt_labels[i]
-            difficult = None if difficults is None else difficults[i]
+            gt_label = gt_labels[i].numpy()
+            difficult = None if difficults is None \
+                else difficults[i].numpy()
             bbox_num = bbox_lengths[i]
             bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
             score = scores[bbox_idx:bbox_idx + bbox_num]

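The rewritten `VOCMetric.update` indexes first and converts per sample, which works whether `gt_bbox` arrives as one padded tensor or, on the `collate_batch=False` path, as a list of variable-length tensors. A small illustration of the list case, assuming Paddle is installed:

```python
import paddle

# With collate_batch=False each image keeps its true box count, so there are
# no zero-padded rows to prune before updating the metric:
gt_boxes = [paddle.rand([3, 4]), paddle.rand([7, 4])]   # 3 vs. 7 boxes

for i in range(len(gt_boxes)):
    gt_box = gt_boxes[i].numpy()   # per-sample conversion, as in the diff
    print(gt_box.shape)            # (3, 4) then (7, 4)
```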
+ 20 - 69
dygraph/paddlex/cv/transforms/batch_operators.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import traceback
-import multiprocessing as mp
 import random
 import numpy as np
 try:
@@ -27,10 +26,10 @@ from paddlex.utils import logging
 
 
 class BatchCompose(Transform):
-    def __init__(self, batch_transforms=None):
+    def __init__(self, batch_transforms=None, collate_batch=True):
         super(BatchCompose, self).__init__()
         self.batch_transforms = batch_transforms
-        self.lock = mp.Lock()
+        self.collate_batch = collate_batch
 
     def __call__(self, samples):
         if self.batch_transforms is not None:
@@ -46,7 +45,23 @@ class BatchCompose(Transform):
 
         samples = _Permute()(samples)
 
-        batch_data = default_collate_fn(samples)
+        extra_key = ['h', 'w', 'flipped']
+        for k in extra_key:
+            for sample in samples:
+                if k in sample:
+                    sample.pop(k)
+
+        if self.collate_batch:
+            batch_data = default_collate_fn(samples)
+        else:
+            batch_data = {}
+            for k in samples[0].keys():
+                tmp_data = []
+                for i in range(len(samples)):
+                    tmp_data.append(samples[i][k])
+                if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
+                    tmp_data = np.stack(tmp_data, axis=0)
+                batch_data[k] = tmp_data
         return batch_data
 
 
@@ -133,10 +148,9 @@ class BatchRandomResizeByShort(Transform):
 
 
 class _BatchPadding(Transform):
-    def __init__(self, pad_to_stride=0, pad_gt=False):
+    def __init__(self, pad_to_stride=0):
         super(_BatchPadding, self).__init__()
         self.pad_to_stride = pad_to_stride
-        self.pad_gt = pad_gt
 
     def __call__(self, samples):
         coarsest_stride = self.pad_to_stride
@@ -155,69 +169,6 @@ class _BatchPadding(Transform):
             padding_im[:im_h, :im_w, :] = im
             data['image'] = padding_im
 
-        if self.pad_gt:
-            gt_num = []
-            if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
-                    'gt_poly']) > 0:
-                pad_mask = True
-            else:
-                pad_mask = False
-
-            if pad_mask:
-                poly_num = []
-                poly_part_num = []
-                point_num = []
-
-            for data in samples:
-                gt_num.append(data['gt_bbox'].shape[0])
-                if pad_mask:
-                    poly_num.append(len(data['gt_poly']))
-                    for poly in data['gt_poly']:
-                        poly_part_num.append(int(len(poly)))
-                        for p_p in poly:
-                            point_num.append(int(len(p_p) / 2))
-            gt_num_max = max(gt_num)
-
-            for i, data in enumerate(samples):
-                gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
-                gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
-                is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
-
-                if pad_mask:
-                    poly_num_max = max(poly_num)
-                    poly_part_num_max = max(poly_part_num)
-                    point_num_max = max(point_num)
-                    gt_masks_data = -np.ones(
-                        [poly_num_max, poly_part_num_max, point_num_max, 2],
-                        dtype=np.float32)
-
-                gt_num = data['gt_bbox'].shape[0]
-                gt_box_data[0:gt_num, :] = data['gt_bbox']
-                gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
-                if 'is_crowd' in data:
-                    is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
-                    data['is_crowd'] = is_crowd_data
-
-                data['gt_bbox'] = gt_box_data
-                data['gt_class'] = gt_class_data
-
-                if pad_mask:
-                    for j, poly in enumerate(data['gt_poly']):
-                        for k, p_p in enumerate(poly):
-                            pp_np = np.array(p_p).reshape(-1, 2)
-                            gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
-                    data['gt_poly'] = gt_masks_data
-
-                if 'gt_score' in data:
-                    gt_score_data = np.zeros([gt_num_max], dtype=np.float32)
-                    gt_score_data[0:gt_num] = data['gt_score'][:gt_num, 0]
-                    data['gt_score'] = gt_score_data
-
-                if 'difficult' in data:
-                    diff_data = np.zeros([gt_num_max], dtype=np.int32)
-                    diff_data[0:gt_num] = data['difficult'][:gt_num, 0]
-                    data['difficult'] = diff_data
-
         return samples
 
 

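`default_collate_fn` stacks every field into a single array and therefore needs equal shapes per sample; the new `collate_batch=False` branch instead keeps `gt_*`, `is_crowd`, and `difficult` fields as plain lists. A runnable sketch of that branch:

```python
import numpy as np

def collate_no_batch(samples):
    """Sketch of BatchCompose's new collate_batch=False path."""
    batch_data = {}
    for k in samples[0].keys():
        tmp_data = [s[k] for s in samples]
        if 'gt_' not in k and 'is_crowd' not in k and 'difficult' not in k:
            tmp_data = np.stack(tmp_data, axis=0)  # images etc. -> one array
        batch_data[k] = tmp_data                   # gt fields stay a list
    return batch_data

batch = collate_no_batch([
    {'image': np.zeros((3, 8, 8)), 'gt_bbox': np.zeros((2, 4))},
    {'image': np.zeros((3, 8, 8)), 'gt_bbox': np.zeros((5, 4))},
])
assert batch['image'].shape == (2, 3, 8, 8)
assert [b.shape[0] for b in batch['gt_bbox']] == [2, 5]
```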
+ 1 - 1
dygraph/paddlex/utils/__init__.py

@@ -19,6 +19,6 @@ from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
                     DisablePrint)
 from .checkpoint import get_pretrain_weights, load_pretrain_weights
 from .env import get_environ_info, get_num_workers, init_parallel_env
-from .download import download_and_decompress
+from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats
 from .shm import _get_shared_memory_size_in_M

+ 0 - 1
dygraph/tutorials/train/README.md

@@ -12,7 +12,6 @@
 |image_classification/darknet53.py | Image classification with DarkNet53 | Vegetable classification |
 |image_classification/xception41.py | Image classification with Xception41 | Vegetable classification |
 |image_classification/densenet121.py | Image classification with DenseNet121 | Vegetable classification |
-|object_detection/faster_rcnn_r34_fpn.py | Object detection with FasterRCNN | Insect detection |
 |object_detection/faster_rcnn_r50_fpn.py | Object detection with FasterRCNN | Insect detection |
 |object_detection/ppyolo.py | Object detection with PPYOLO | Insect detection |
 |object_detection/ppyolotiny.py | Object detection with PPYOLOTiny | Insect detection |

+ 5 - 6
dygraph/tutorials/train/object_detection/faster_rcnn_r34_fpn.py → dygraph/tutorials/train/object_detection/faster_rcnn_hrnet_w18.py

@@ -40,19 +40,18 @@ eval_dataset = pdx.datasets.VOCDetection(
 # Initialize the model and start training
 # Training metrics can be visualized with VisualDL; see https://github.com/PaddlePaddle/PaddleX/tree/release/2.0-rc/tutorials/train#visualdl可视化训练指标
 num_classes = len(train_dataset.labels)
-model = pdx.models.FasterRCNN(
-    num_classes=num_classes, backbone='ResNet34', with_fpn=True)
+model = pdx.models.FasterRCNN(num_classes=num_classes, backbone='HRNet_W18')
 
 # API reference: https://github.com/PaddlePaddle/PaddleX/blob/release/2.0-rc/paddlex/cv/models/detector.py#L154
 # Parameter descriptions and tuning notes: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
 model.train(
-    num_epochs=12,
+    num_epochs=24,
     train_dataset=train_dataset,
     train_batch_size=2,
     eval_dataset=eval_dataset,
     learning_rate=0.0025,
-    lr_decay_epochs=[8, 11],
-    warmup_steps=500,
+    lr_decay_epochs=[16, 22],
+    warmup_steps=1000,
     warmup_start_lr=0.00025,
-    save_dir='output/faster_rcnn_r50_fpn',
+    save_dir='output/faster_rcnn_hrnet_w18',
     use_vdl=True)