@@ -37,6 +37,7 @@ class VOCDetection(Dataset):
             the actual number of CPU cores on the system: if half the number of CPU cores is
             greater than 8, `num_workers` is set to 8; otherwise, to half the number of CPU cores.
         shuffle (bool): Whether to shuffle the samples in the dataset. Defaults to False.
+        allow_empty (bool): Whether to load negative samples. Defaults to False.
     """

     def __init__(self,
@@ -45,7 +46,8 @@ class VOCDetection(Dataset):
                  label_list,
                  transforms=None,
                  num_workers='auto',
-                 shuffle=False):
+                 shuffle=False,
+                 allow_empty=False):
         # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
         # or matplotlib.backends is imported for the first time
         # pycocotools import matplotlib
@@ -69,7 +71,9 @@ class VOCDetection(Dataset):
         self.batch_transforms = None
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
+        self.allow_empty = allow_empty
         self.file_list = list()
+        self.neg_file_list = list()
         self.labels = list()

         annotations = dict()
@@ -121,17 +125,10 @@ class VOCDetection(Dataset):
                     continue
                 tree = ET.parse(xml_file)
                 if tree.find('id') is None:
-                    im_id = np.array([ct])
+                    im_id = np.asarray([ct])
                 else:
                     ct = int(tree.find('id').text)
-                    im_id = np.array([int(tree.find('id').text)])
-                pattern = re.compile('<object>', re.IGNORECASE)
-                obj_match = pattern.findall(
-                    str(ET.tostringlist(tree.getroot())))
-                if len(obj_match) == 0:
-                    continue
-                obj_tag = obj_match[0][1:-1]
-                objs = tree.findall(obj_tag)
+                    im_id = np.asarray([int(tree.find('id').text)])
                 pattern = re.compile('<size>', re.IGNORECASE)
                 size_tag = pattern.findall(
                     str(ET.tostringlist(tree.getroot())))
@@ -149,18 +146,26 @@ class VOCDetection(Dataset):
                 else:
                     im_w = 0
                     im_h = 0
-                gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
-                gt_class = np.zeros((len(objs), 1), dtype=np.int32)
-                gt_score = np.ones((len(objs), 1), dtype=np.float32)
-                is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
-                difficult = np.zeros((len(objs), 1), dtype=np.int32)
-                skipped_indices = list()
+
+                pattern = re.compile('<object>', re.IGNORECASE)
+                obj_match = pattern.findall(
+                    str(ET.tostringlist(tree.getroot())))
+                if len(obj_match) > 0:
+                    obj_tag = obj_match[0][1:-1]
+                    objs = tree.findall(obj_tag)
+                else:
+                    objs = list()
+
+                gt_bbox = list()
+                gt_class = list()
+                gt_score = list()
+                is_crowd = list()
+                difficult = list()
                 for i, obj in enumerate(objs):
                     pattern = re.compile('<name>', re.IGNORECASE)
                     name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
                         1:-1]
                     cname = obj.find(name_tag).text.strip()
-                    gt_class[i][0] = cname2cid[cname]
                     pattern = re.compile('<difficult>', re.IGNORECASE)
                     diff_tag = pattern.findall(str(ET.tostringlist(obj)))
                     if len(diff_tag) == 0:
@@ -204,15 +209,16 @@ class VOCDetection(Dataset):
                         y2 = min(im_h - 1, y2)

                     if not (x2 >= x1 and y2 >= y1):
-                        skipped_indices.append(i)
                         logging.warning(
                             "Bounding box for object {} does not satisfy x1 <= x2 and y1 <= y2, "
                             "so this object is skipped".format(i))
                         continue

-                    gt_bbox[i] = [x1, y1, x2, y2]
-                    is_crowd[i][0] = 0
-                    difficult[i][0] = _difficult
+                    gt_bbox.append([x1, y1, x2, y2])
+                    gt_class.append([cname2cid[cname]])
+                    gt_score.append([1.])
+                    is_crowd.append([0])
+                    difficult.append([_difficult])
                     annotations['annotations'].append({
                         'iscrowd': 0,
                         'image_id': int(im_id[0]),
@@ -224,16 +230,16 @@ class VOCDetection(Dataset):
                     })
                     ann_ct += 1

-                if skipped_indices:
-                    gt_bbox = np.delete(gt_bbox, skipped_indices, axis=0)
-                    gt_class = np.delete(gt_class, skipped_indices, axis=0)
-                    gt_score = np.delete(gt_score, skipped_indices, axis=0)
-                    is_crowd = np.delete(is_crowd, skipped_indices, axis=0)
-                    difficult = np.delete(difficult, skipped_indices, axis=0)
+                gt_bbox = np.array(gt_bbox, dtype=np.float32)
+                gt_class = np.array(gt_class, dtype=np.int32)
+                gt_score = np.array(gt_score, dtype=np.float32)
+                is_crowd = np.array(is_crowd, dtype=np.int32)
+                difficult = np.array(difficult, dtype=np.int32)

                 im_info = {
                     'im_id': im_id,
-                    'image_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array(
+                        [im_h, im_w], dtype=np.int32),
                 }
                 label_info = {
                     'is_crowd': is_crowd,
@@ -243,7 +249,7 @@ class VOCDetection(Dataset):
                     'difficult': difficult
                 }

-                if gt_bbox.size != 0:
+                if gt_bbox.size > 0:
                     self.file_list.append({
                         'image': img_file,
                         **
@@ -251,22 +257,38 @@ class VOCDetection(Dataset):
                         **
                         label_info
                     })
-                    ct += 1
                     annotations['images'].append({
                         'height': im_h,
                         'width': im_w,
                         'id': int(im_id[0]),
                         'file_name': osp.split(img_file)[1]
                     })
+                else:
+                    self.neg_file_list.append({
+                        'image': img_file,
+                        **
+                        im_info,
+                        **
+                        label_info
+                    })
+                ct += 1
+
                 if self.use_mix:
                     self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
                 else:
                     self.num_max_boxes = max(self.num_max_boxes, len(objs))

-        if not len(self.file_list) > 0:
-            raise Exception('not found any voc record in %s' % (file_list))
-        logging.info("{} samples in file {}".format(
-            len(self.file_list), file_list))
+        if not ct:
+            logging.error(
+                "No voc record found in %s" % file_list, exit=True)
+        logging.info(
+            "{} samples in file {}, including {} positive samples and {} negative samples.".
+            format(
+                len(self.file_list) + len(self.neg_file_list), file_list,
+                len(self.file_list), len(self.neg_file_list)))
+
+        if self.allow_empty:
+            self.file_list += self.neg_file_list
         self.num_samples = len(self.file_list)
         self.coco_gt = COCO()
         self.coco_gt.dataset = annotations
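
For reference, a minimal usage sketch of the new `allow_empty` switch. This is illustrative only: the dataset paths are hypothetical placeholders, and it assumes the constructor's leading `data_dir` and `file_list` arguments, which sit just above the first hunk shown here.

    # Illustrative sketch; paths are hypothetical placeholders.
    train_dataset = VOCDetection(
        data_dir='dataset',                    # assumed leading argument
        file_list='dataset/train_list.txt',    # assumed leading argument
        label_list='dataset/labels.txt',
        transforms=None,
        num_workers='auto',
        shuffle=True,
        allow_empty=True)  # negative samples are appended to file_list

With the default `allow_empty=False`, negative samples are collected into `neg_file_list` and reported in the log, but excluded from `file_list`, matching the previous behavior of skipping annotation-free images.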