浏览代码

determine num_max_boxes according to dataset

will-jl944 4 年之前
父节点
当前提交
85784c783f
共有 3 个文件被更改，包括 16 次插入和 3 次删除
  1. 7 0
      dygraph/paddlex/cv/datasets/coco.py
  2. 7 0
      dygraph/paddlex/cv/datasets/voc.py
  3. 2 3
      dygraph/paddlex/cv/models/detector.py

+ 7 - 0
dygraph/paddlex/cv/datasets/coco.py

@@ -57,12 +57,14 @@ class CocoDetection(VOCDetection):
         super(VOCDetection, self).__init__()
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
+        self.num_max_boxes = 50
         self.use_mix = False
         if self.transforms is not None:
             for op in self.transforms.transforms:
                 if isinstance(op, MixupImage):
                     self.mixup_op = copy.deepcopy(op)
                     self.use_mix = True
+                    self.num_max_boxes *= 2
                     break
 
         self.batch_transforms = None
@@ -153,6 +155,11 @@ class CocoDetection(VOCDetection):
                 **
                 label_info
             }))
+        if self.use_mix:
+            self.num_max_boxes = max(self.num_max_boxes, 2 * len(instances))
+        else:
+            self.num_max_boxes = max(self.num_max_boxes, len(instances))
+
         if not len(self.file_list) > 0:
             raise Exception('not found any coco record in %s' % ann_file)
         logging.info("{} samples in file {}".format(

+ 7 - 0
dygraph/paddlex/cv/datasets/voc.py

@@ -56,6 +56,7 @@ class VOCDetection(Dataset):
         super(VOCDetection, self).__init__()
         self.data_fields = None
         self.transforms = copy.deepcopy(transforms)
+        self.num_max_boxes = 50
 
         self.use_mix = False
         if self.transforms is not None:
@@ -63,6 +64,7 @@ class VOCDetection(Dataset):
                 if isinstance(op, MixupImage):
                     self.mixup_op = copy.deepcopy(op)
                     self.use_mix = True
+                    self.num_max_boxes *= 2
                     break
 
         self.batch_transforms = None
@@ -257,6 +259,11 @@ class VOCDetection(Dataset):
                         'id': int(im_id[0]),
                         'file_name': osp.split(img_file)[1]
                     })
+                if self.use_mix:
+                    self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
+                else:
+                    self.num_max_boxes = max(self.num_max_boxes, len(objs))
+
         if not len(self.file_list) > 0:
             raise Exception('not found any voc record in %s' % (file_list))
         logging.info("{} samples in file {}".format(

+ 2 - 3
dygraph/paddlex/cv/models/detector.py

@@ -192,9 +192,10 @@ class BaseDetector(BaseModel):
                 "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
             self.metric = metric.lower()
 
+        self.labels = train_dataset.labels
+        self.num_max_boxes = train_dataset.num_max_boxes
         train_dataset.batch_transforms = self._compose_batch_transform(
             train_dataset.transforms, mode='train')
-        self.labels = train_dataset.labels
 
         # build optimizer if not defined
         if optimizer is None:
@@ -1178,7 +1179,6 @@ class PPYOLOTiny(YOLOv3):
         self.anchors = anchors
         self.anchor_masks = anchor_masks
         self.downsample_ratios = downsample_ratios
-        self.num_max_boxes = 100
         self.model_name = 'PPYOLOTiny'
 
 
@@ -1302,7 +1302,6 @@ class PPYOLOv2(YOLOv3):
         self.anchors = anchors
         self.anchor_masks = anchor_masks
         self.downsample_ratios = downsample_ratios
-        self.num_max_boxes = 100
         self.model_name = 'PPYOLOv2'