|
|
@@ -34,6 +34,7 @@ class CocoDetection(VOCDetection):
|
|
|
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。
|
|
|
shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
|
|
|
allow_empty (bool): 是否加载负样本。默认为False。
|
|
|
+ empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
@@ -42,7 +43,8 @@ class CocoDetection(VOCDetection):
|
|
|
transforms=None,
|
|
|
num_workers='auto',
|
|
|
shuffle=False,
|
|
|
- allow_empty=False):
|
|
|
+ allow_empty=False,
|
|
|
+ empty_ratio=1.):
|
|
|
# matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
|
|
|
# or matplotlib.backends is imported for the first time
|
|
|
# pycocotools import matplotlib
|
|
|
@@ -73,8 +75,9 @@ class CocoDetection(VOCDetection):
|
|
|
self.num_workers = get_num_workers(num_workers)
|
|
|
self.shuffle = shuffle
|
|
|
self.allow_empty = allow_empty
|
|
|
+ self.empty_ratio = empty_ratio
|
|
|
self.file_list = list()
|
|
|
- self.neg_file_list = list()
|
|
|
+ neg_file_list = list()
|
|
|
self.labels = list()
|
|
|
|
|
|
coco = COCO(ann_file)
|
|
|
@@ -165,7 +168,7 @@ class CocoDetection(VOCDetection):
|
|
|
}
|
|
|
|
|
|
if is_empty:
|
|
|
- self.neg_file_list.append({
|
|
|
+ neg_file_list.append({
|
|
|
'image': im_fname,
|
|
|
**
|
|
|
im_info,
|
|
|
@@ -191,14 +194,14 @@ class CocoDetection(VOCDetection):
|
|
|
if not ct:
|
|
|
logging.error(
|
|
|
"No coco record found in %s' % (ann_file)", exit=True)
|
|
|
+ self.pos_num = len(self.file_list)
|
|
|
+ if self.allow_empty:
|
|
|
+ self.file_list += self._sample_empty(neg_file_list)
|
|
|
logging.info(
|
|
|
"{} samples in file {}, including {} positive samples and {} negative samples.".
|
|
|
format(
|
|
|
- len(self.file_list) + len(self.neg_file_list), ann_file,
|
|
|
- len(self.file_list), len(self.neg_file_list)))
|
|
|
-
|
|
|
- if self.allow_empty:
|
|
|
- self.file_list += self.neg_file_list
|
|
|
+ len(self.file_list), ann_file, self.pos_num,
|
|
|
+ len(self.file_list) - self.pos_num))
|
|
|
self.num_samples = len(self.file_list)
|
|
|
|
|
|
self._epoch = 0
|