Răsfoiți Sursa

add empty_ratio for detection dataset

will-jl944 4 ani în urmă
părinte
comite
dcacb24e08
1 fișier modificat, cu 37 de adăugiri și 21 de ștergeri
  1. 37 21
      paddlex/cv/datasets/voc.py

+ 37 - 21
paddlex/cv/datasets/voc.py

@@ -39,6 +39,7 @@ class VOCDetection(Dataset):
             一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         allow_empty (bool): 是否加载负样本。默认为False。
+        empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
     """
 
     def __init__(self,
@@ -48,7 +49,8 @@ class VOCDetection(Dataset):
                  transforms=None,
                  num_workers='auto',
                  shuffle=False,
-                 allow_empty=False):
+                 allow_empty=False,
+                 empty_ratio=1.):
         # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
         # or matplotlib.backends is imported for the first time
         # pycocotools import matplotlib
@@ -73,8 +75,9 @@ class VOCDetection(Dataset):
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
         self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
         self.file_list = list()
-        self.neg_file_list = list()
+        neg_file_list = list()
         self.labels = list()
 
         annotations = dict()
@@ -265,7 +268,7 @@ class VOCDetection(Dataset):
                         'file_name': osp.split(img_file)[1]
                     })
                 else:
-                    self.neg_file_list.append({
+                    neg_file_list.append({
                         'image': img_file,
                         **
                         im_info,
@@ -282,14 +285,14 @@ class VOCDetection(Dataset):
         if not ct:
             logging.error(
                 "No voc record found in %s' % (file_list)", exit=True)
+        self.pos_num = len(self.file_list)
+        if self.allow_empty:
+            self.file_list += self._sample_empty(neg_file_list)
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".
             format(
-                len(self.file_list) + len(self.neg_file_list), file_list,
-                len(self.file_list), len(self.neg_file_list)))
-
-        if self.allow_empty:
-            self.file_list += self.neg_file_list
+                len(self.file_list), file_list, self.pos_num,
+                len(self.file_list) - self.pos_num))
         self.num_samples = len(self.file_list)
         self.coco_gt = COCO()
         self.coco_gt.dataset = annotations
@@ -323,20 +326,25 @@ class VOCDetection(Dataset):
     def set_epoch(self, epoch_id):
         self._epoch = epoch_id
 
-    def add_negative_samples(self, image_dir):
+    def add_negative_samples(self, image_dir, empty_ratio=1):
         """将背景图片加入训练
 
         Args:
             image_dir (str):背景图片所在的文件夹目录。
+            empty_ratio (float or None): 用于指定负样本占总样本数的比例。如果为None,保留数据集初始化时设置的`empty_ratio`值,
+                否则更新原有`empty_ratio`值。如果小于0或大于等于1,则保留全部的负样本。默认为1。
 
         """
         import cv2
         if not osp.isdir(image_dir):
             raise Exception("{} is not a valid image directory.".format(
                 image_dir))
+        if empty_ratio is not None:
+            self.empty_ratio = empty_ratio
         image_list = os.listdir(image_dir)
-        max_img_id = len(self.file_list)
-        ct = 0
+        max_img_id = max(
+            len(self.file_list) - 1, max(self.coco_gt.getImgIds()))
+        neg_file_list = list()
         for image in image_list:
             if not is_pic(image):
                 continue
@@ -366,20 +374,28 @@ class VOCDetection(Dataset):
             if 'gt_poly' in self.file_list[0]:
                 label_info['gt_poly'] = []
 
-            self.neg_file_list.append({
+            neg_file_list.append({
                 'image': im_fname,
                 **
                 im_info,
                 **
                 label_info
             })
-            self.file_list.append({
-                'image': im_fname,
-                **
-                im_info,
-                **
-                label_info
-            })
-            ct += 1
+        self.file_list += self._sample_empty(neg_file_list)
+        logging.info(
+            "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
+            format(
+                len(self.file_list) - self.num_samples, self.pos_num,
+                len(self.file_list) - self.pos_num))
         self.num_samples = len(self.file_list)
-        logging.info("{} negative samples added.".format(ct))
+
+    def _sample_empty(self, neg_file_list):
+        """Select the negative samples to append to `self.file_list`.
+
+        Args:
+            neg_file_list (list): Candidate negative (empty) samples.
+
+        Returns:
+            list: `neg_file_list` itself when `self.empty_ratio` is outside
+                [0, 1); otherwise a random subset sized so that negatives
+                make up at most `self.empty_ratio` of all samples.
+        """
+        if not 0. <= self.empty_ratio < 1.:
+            return neg_file_list
+
+        import random
+        total_num = len(self.file_list)
+        # Negatives that are already part of `self.file_list`.
+        neg_num = total_num - self.pos_num
+        # Solve (neg_num + k) / (total_num + k) == empty_ratio for k.
+        sample_num = (total_num * self.empty_ratio - neg_num) // (
+            1 - self.empty_ratio)
+        # `//` on floats yields a float, and the result is negative when the
+        # ratio is already exceeded; random.sample() needs an int within
+        # [0, len(neg_file_list)].
+        sample_num = max(0, min(int(sample_num), len(neg_file_list)))
+        return random.sample(neg_file_list, sample_num)