|
|
@@ -39,6 +39,7 @@ class VOCDetection(Dataset):
|
|
|
一半。
|
|
|
shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
|
|
|
allow_empty (bool): 是否加载负样本。默认为False。
|
|
|
+ empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
@@ -48,7 +49,8 @@ class VOCDetection(Dataset):
|
|
|
transforms=None,
|
|
|
num_workers='auto',
|
|
|
shuffle=False,
|
|
|
- allow_empty=False):
|
|
|
+ allow_empty=False,
|
|
|
+ empty_ratio=1.):
|
|
|
# matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
|
|
|
# or matplotlib.backends is imported for the first time
|
|
|
# pycocotools import matplotlib
|
|
|
@@ -73,8 +75,9 @@ class VOCDetection(Dataset):
|
|
|
self.num_workers = get_num_workers(num_workers)
|
|
|
self.shuffle = shuffle
|
|
|
self.allow_empty = allow_empty
|
|
|
+ self.empty_ratio = empty_ratio
|
|
|
self.file_list = list()
|
|
|
- self.neg_file_list = list()
|
|
|
+ neg_file_list = list()
|
|
|
self.labels = list()
|
|
|
|
|
|
annotations = dict()
|
|
|
@@ -265,7 +268,7 @@ class VOCDetection(Dataset):
|
|
|
'file_name': osp.split(img_file)[1]
|
|
|
})
|
|
|
else:
|
|
|
- self.neg_file_list.append({
|
|
|
+ neg_file_list.append({
|
|
|
'image': img_file,
|
|
|
**
|
|
|
im_info,
|
|
|
@@ -282,14 +285,14 @@ class VOCDetection(Dataset):
|
|
|
if not ct:
|
|
|
logging.error(
|
|
|
"No voc record found in %s' % (file_list)", exit=True)
|
|
|
+ self.pos_num = len(self.file_list)
|
|
|
+ if self.allow_empty:
|
|
|
+ self.file_list += self._sample_empty(neg_file_list)
|
|
|
logging.info(
|
|
|
"{} samples in file {}, including {} positive samples and {} negative samples.".
|
|
|
format(
|
|
|
- len(self.file_list) + len(self.neg_file_list), file_list,
|
|
|
- len(self.file_list), len(self.neg_file_list)))
|
|
|
-
|
|
|
- if self.allow_empty:
|
|
|
- self.file_list += self.neg_file_list
|
|
|
+ len(self.file_list), file_list, self.pos_num,
|
|
|
+ len(self.file_list) - self.pos_num))
|
|
|
self.num_samples = len(self.file_list)
|
|
|
self.coco_gt = COCO()
|
|
|
self.coco_gt.dataset = annotations
|
|
|
@@ -323,20 +326,25 @@ class VOCDetection(Dataset):
|
|
|
def set_epoch(self, epoch_id):
    """Store the given epoch index on the dataset as ``self._epoch``.

    Args:
        epoch_id: Epoch index to remember.
    """
    self._epoch = epoch_id
|
|
|
|
|
|
- def add_negative_samples(self, image_dir):
|
|
|
+ def add_negative_samples(self, image_dir, empty_ratio=1):
|
|
|
"""将背景图片加入训练
|
|
|
|
|
|
Args:
|
|
|
image_dir (str):背景图片所在的文件夹目录。
|
|
|
+ empty_ratio (float or None): 用于指定负样本占总样本数的比例。如果为None,保留数据集初始化时设置的`empty_ratio`值,
|
|
|
+ 否则更新原有`empty_ratio`值。如果小于0或大于等于1,则保留全部的负样本。默认为1。
|
|
|
|
|
|
"""
|
|
|
import cv2
|
|
|
if not osp.isdir(image_dir):
|
|
|
raise Exception("{} is not a valid image directory.".format(
|
|
|
image_dir))
|
|
|
+ if empty_ratio is not None:
|
|
|
+ self.empty_ratio = empty_ratio
|
|
|
image_list = os.listdir(image_dir)
|
|
|
- max_img_id = len(self.file_list)
|
|
|
- ct = 0
|
|
|
+ max_img_id = max(
|
|
|
+ len(self.file_list) - 1, max(self.coco_gt.getImgIds()))
|
|
|
+ neg_file_list = list()
|
|
|
for image in image_list:
|
|
|
if not is_pic(image):
|
|
|
continue
|
|
|
@@ -366,20 +374,28 @@ class VOCDetection(Dataset):
|
|
|
if 'gt_poly' in self.file_list[0]:
|
|
|
label_info['gt_poly'] = []
|
|
|
|
|
|
- self.neg_file_list.append({
|
|
|
+ neg_file_list.append({
|
|
|
'image': im_fname,
|
|
|
**
|
|
|
im_info,
|
|
|
**
|
|
|
label_info
|
|
|
})
|
|
|
- self.file_list.append({
|
|
|
- 'image': im_fname,
|
|
|
- **
|
|
|
- im_info,
|
|
|
- **
|
|
|
- label_info
|
|
|
- })
|
|
|
- ct += 1
|
|
|
+ self.file_list += self._sample_empty(neg_file_list)
|
|
|
+ logging.info(
|
|
|
+ "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
|
|
|
+ format(
|
|
|
+ len(self.file_list) - self.num_samples, self.pos_num,
|
|
|
+ len(self.file_list) - self.pos_num))
|
|
|
self.num_samples = len(self.file_list)
|
|
|
- logging.info("{} negative samples added.".format(ct))
|
|
|
+
|
|
|
def _sample_empty(self, neg_file_list):
    """Pick the negative (background) samples to append to ``self.file_list``.

    When ``self.empty_ratio`` lies in ``[0, 1)``, a random subset of
    ``neg_file_list`` is drawn so that negatives make up at most that
    fraction of the resulting dataset, counting the negatives already
    present in ``self.file_list`` (``len(self.file_list) - self.pos_num``).
    Any other ratio keeps every candidate.

    Args:
        neg_file_list (list): Candidate negative sample records.

    Returns:
        list: The records to add. This is ``neg_file_list`` itself when no
            sub-sampling is requested, otherwise a new randomly sampled list
            (possibly empty).
    """
    if not 0. <= self.empty_ratio < 1.:
        return neg_file_list

    import random

    total_num = len(self.file_list)
    neg_num = total_num - self.pos_num
    # Largest k satisfying (neg_num + k) / (total_num + k) <= empty_ratio.
    quota = (total_num * self.empty_ratio - neg_num) // (1 - self.empty_ratio)
    # Floor division on floats yields a float, and the quota is negative once
    # the dataset already exceeds the requested ratio; random.sample needs an
    # int within [0, len(population)].
    sample_num = max(0, min(int(quota), len(neg_file_list)))
    return random.sample(neg_file_list, sample_num)
|