소스 검색

add empty_ratio for detection dataset

will-jl944 4 년 전
부모
커밋
c92208f7df
1개의 변경된 파일, 11개의 추가 그리고 8개의 삭제
  1. 11 8
      paddlex/cv/datasets/coco.py

+ 11 - 8
paddlex/cv/datasets/coco.py

@@ -34,6 +34,7 @@ class CocoDetection(VOCDetection):
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
         allow_empty (bool): 是否加载负样本。默认为False。
+        empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
     """
 
     def __init__(self,
@@ -42,7 +43,8 @@ class CocoDetection(VOCDetection):
                  transforms=None,
                  num_workers='auto',
                  shuffle=False,
-                 allow_empty=False):
+                 allow_empty=False,
+                 empty_ratio=1.):
         # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
         # or matplotlib.backends is imported for the first time
         # pycocotools import matplotlib
@@ -73,8 +75,9 @@ class CocoDetection(VOCDetection):
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
         self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
         self.file_list = list()
-        self.neg_file_list = list()
+        neg_file_list = list()
         self.labels = list()
 
         coco = COCO(ann_file)
@@ -165,7 +168,7 @@ class CocoDetection(VOCDetection):
             }
 
             if is_empty:
-                self.neg_file_list.append({
+                neg_file_list.append({
                     'image': im_fname,
                     **
                     im_info,
@@ -191,14 +194,14 @@ class CocoDetection(VOCDetection):
         if not ct:
             logging.error(
                 "No coco record found in %s' % (ann_file)", exit=True)
+        self.pos_num = len(self.file_list)
+        if self.allow_empty:
+            self.file_list += self._sample_empty(neg_file_list)
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".
             format(
-                len(self.file_list) + len(self.neg_file_list), ann_file,
-                len(self.file_list), len(self.neg_file_list)))
-
-        if self.allow_empty:
-            self.file_list += self.neg_file_list
+                len(self.file_list), ann_file, self.pos_num,
+                len(self.file_list) - self.pos_num))
         self.num_samples = len(self.file_list)
 
         self._epoch = 0