Przeglądaj źródła

support training with negative samples for det models

will-jl944 4 lat temu
rodzic
commit
9edd5d7423
2 zmienionych plików z 115 dodań i 58 usunięć
  1. 59 24
      paddlex/cv/datasets/coco.py
  2. 56 34
      paddlex/cv/datasets/voc.py

+ 59 - 24
paddlex/cv/datasets/coco.py

@@ -40,7 +40,8 @@ class CocoDetection(VOCDetection):
                  ann_file,
                  transforms=None,
                  num_workers='auto',
-                 shuffle=False):
+                 shuffle=False,
+                 allow_empty=False):
         # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
         # or matplotlib.backends is imported for the first time
         # pycocotools import matplotlib
@@ -70,12 +71,14 @@ class CocoDetection(VOCDetection):
         self.batch_transforms = None
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
+        self.allow_empty = allow_empty
         self.file_list = list()
+        self.neg_file_list = list()
         self.labels = list()
 
         coco = COCO(ann_file)
         self.coco_gt = coco
-        img_ids = coco.getImgIds()
+        img_ids = sorted(coco.getImgIds())
         cat_ids = coco.getCatIds()
         catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
         cname2clsid = dict({
@@ -85,7 +88,10 @@ class CocoDetection(VOCDetection):
         for label, cid in sorted(cname2clsid.items(), key=lambda d: d[1]):
             self.labels.append(label)
         logging.info("Starting to read file list from dataset...")
+
+        ct = 0
         for img_id in img_ids:
+            is_empty = False
             img_anno = coco.loadImgs(img_id)[0]
             im_fname = osp.join(data_dir, img_anno['file_name'])
             if not is_pic(im_fname):
@@ -111,6 +117,11 @@ class CocoDetection(VOCDetection):
                         "im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}."
                         .format(img_id, float(inst['area']), x1, y1, x2, y2))
             num_bbox = len(bboxes)
+            if num_bbox == 0 and not self.allow_empty:
+                continue
+            elif num_bbox == 0:
+                is_empty = True
+
             gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
             gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
             gt_score = np.ones((num_bbox, 1), dtype=np.float32)
@@ -125,11 +136,19 @@ class CocoDetection(VOCDetection):
                 gt_bbox[i, :] = box['clean_bbox']
                 is_crowd[i][0] = box['iscrowd']
                 if 'segmentation' in box and box['iscrowd'] == 1:
-                    gt_poly[i] = [[0.0, 0.0], ]
+                    gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
                 elif 'segmentation' in box and box['segmentation']:
-                    gt_poly[i] = box['segmentation']
+                    if not np.array(box[
+                            'segmentation']).size > 0 and not self.allow_empty:
+                        bboxes.pop(i)
+                        gt_poly.pop(i)
+                        np.delete(is_crowd, i)
+                        np.delete(gt_class, i)
+                        np.delete(gt_bbox, i)
+                    else:
+                        gt_poly[i] = box['segmentation']
                     has_segmentation = True
-            if has_segmentation and not any(gt_poly):
+            if has_segmentation and not any(gt_poly) and not self.allow_empty:
                 continue
 
             im_info = {
@@ -145,25 +164,41 @@ class CocoDetection(VOCDetection):
                 'difficult': difficult
             }
 
-            if None in gt_poly:
-                del label_info['gt_poly']
-
-            self.file_list.append(({
-                'image': im_fname,
-                **
-                im_info,
-                **
-                label_info
-            }))
-        if self.use_mix:
-            self.num_max_boxes = max(self.num_max_boxes, 2 * len(instances))
-        else:
-            self.num_max_boxes = max(self.num_max_boxes, len(instances))
-
-        if not len(self.file_list) > 0:
-            raise Exception('not found any coco record in %s' % ann_file)
-        logging.info("{} samples in file {}".format(
-            len(self.file_list), ann_file))
+            if is_empty:
+                self.neg_file_list.append({
+                    'image': im_fname,
+                    **
+                    im_info,
+                    **
+                    label_info
+                })
+            else:
+                self.file_list.append({
+                    'image': im_fname,
+                    **
+                    im_info,
+                    **
+                    label_info
+                })
+            ct += 1
+
+            if self.use_mix:
+                self.num_max_boxes = max(self.num_max_boxes,
+                                         2 * len(instances))
+            else:
+                self.num_max_boxes = max(self.num_max_boxes, len(instances))
+
+        if not ct:
+            logging.error(
+                "No coco record found in %s" % ann_file, exit=True)
+        logging.info(
+            "{} samples in file {}, including {} positive samples and {} negative samples.".
+            format(
+                len(self.file_list) + len(self.neg_file_list), ann_file,
+                len(self.file_list), len(self.neg_file_list)))
+
+        if self.allow_empty:
+            self.file_list += self.neg_file_list
         self.num_samples = len(self.file_list)
 
         self._epoch = 0

+ 56 - 34
paddlex/cv/datasets/voc.py

@@ -37,6 +37,7 @@ class VOCDetection(Dataset):
             系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
             一半。
         shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
+        allow_empty (bool): 是否加载负样本。默认为False。
     """
 
     def __init__(self,
@@ -45,7 +46,8 @@ class VOCDetection(Dataset):
                  label_list,
                  transforms=None,
                  num_workers='auto',
-                 shuffle=False):
+                 shuffle=False,
+                 allow_empty=False):
         # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
         # or matplotlib.backends is imported for the first time
         # pycocotools import matplotlib
@@ -69,7 +71,9 @@ class VOCDetection(Dataset):
         self.batch_transforms = None
         self.num_workers = get_num_workers(num_workers)
         self.shuffle = shuffle
+        self.allow_empty = allow_empty
         self.file_list = list()
+        self.neg_file_list = list()
         self.labels = list()
 
         annotations = dict()
@@ -121,17 +125,10 @@ class VOCDetection(Dataset):
                     continue
                 tree = ET.parse(xml_file)
                 if tree.find('id') is None:
-                    im_id = np.array([ct])
+                    im_id = np.asarray([ct])
                 else:
                     ct = int(tree.find('id').text)
-                    im_id = np.array([int(tree.find('id').text)])
-                pattern = re.compile('<object>', re.IGNORECASE)
-                obj_match = pattern.findall(
-                    str(ET.tostringlist(tree.getroot())))
-                if len(obj_match) == 0:
-                    continue
-                obj_tag = obj_match[0][1:-1]
-                objs = tree.findall(obj_tag)
+                    im_id = np.asarray([int(tree.find('id').text)])
                 pattern = re.compile('<size>', re.IGNORECASE)
                 size_tag = pattern.findall(
                     str(ET.tostringlist(tree.getroot())))
@@ -149,18 +146,26 @@ class VOCDetection(Dataset):
                 else:
                     im_w = 0
                     im_h = 0
-                gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
-                gt_class = np.zeros((len(objs), 1), dtype=np.int32)
-                gt_score = np.ones((len(objs), 1), dtype=np.float32)
-                is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
-                difficult = np.zeros((len(objs), 1), dtype=np.int32)
-                skipped_indices = list()
+
+                pattern = re.compile('<object>', re.IGNORECASE)
+                obj_match = pattern.findall(
+                    str(ET.tostringlist(tree.getroot())))
+                if len(obj_match) > 0:
+                    obj_tag = obj_match[0][1:-1]
+                    objs = tree.findall(obj_tag)
+                else:
+                    objs = list()
+
+                gt_bbox = list()
+                gt_class = list()
+                gt_score = list()
+                is_crowd = list()
+                difficult = list()
                 for i, obj in enumerate(objs):
                     pattern = re.compile('<name>', re.IGNORECASE)
                     name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
                         1:-1]
                     cname = obj.find(name_tag).text.strip()
-                    gt_class[i][0] = cname2cid[cname]
                     pattern = re.compile('<difficult>', re.IGNORECASE)
                     diff_tag = pattern.findall(str(ET.tostringlist(obj)))
                     if len(diff_tag) == 0:
@@ -204,15 +209,16 @@ class VOCDetection(Dataset):
                         y2 = min(im_h - 1, y2)
 
                     if not (x2 >= x1 and y2 >= y1):
-                        skipped_indices.append(i)
                         logging.warning(
                             "Bounding box for object {} does not satisfy x1 <= x2 and y1 <= y2, "
                             "so this object is skipped".format(i))
                         continue
 
-                    gt_bbox[i] = [x1, y1, x2, y2]
-                    is_crowd[i][0] = 0
-                    difficult[i][0] = _difficult
+                    gt_bbox.append([x1, y1, x2, y2])
+                    gt_class.append([cname2cid[cname]])
+                    gt_score.append([1.])
+                    is_crowd.append(0)
+                    difficult.append([_difficult])
                     annotations['annotations'].append({
                         'iscrowd': 0,
                         'image_id': int(im_id[0]),
@@ -224,16 +230,16 @@ class VOCDetection(Dataset):
                     })
                     ann_ct += 1
 
-                if skipped_indices:
-                    gt_bbox = np.delete(gt_bbox, skipped_indices, axis=0)
-                    gt_class = np.delete(gt_class, skipped_indices, axis=0)
-                    gt_score = np.delete(gt_score, skipped_indices, axis=0)
-                    is_crowd = np.delete(is_crowd, skipped_indices, axis=0)
-                    difficult = np.delete(difficult, skipped_indices, axis=0)
+                gt_bbox = np.array(gt_bbox, dtype=np.float32)
+                gt_class = np.array(gt_class, dtype=np.int32)
+                gt_score = np.array(gt_score, dtype=np.float32)
+                is_crowd = np.array(is_crowd, dtype=np.int32)
+                difficult = np.array(difficult, dtype=np.int32)
 
                 im_info = {
                     'im_id': im_id,
-                    'image_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array(
+                        [im_h, im_w], dtype=np.int32),
                 }
                 label_info = {
                     'is_crowd': is_crowd,
@@ -243,7 +249,7 @@ class VOCDetection(Dataset):
                     'difficult': difficult
                 }
 
-                if gt_bbox.size != 0:
+                if gt_bbox.size > 0:
                     self.file_list.append({
                         'image': img_file,
                         **
@@ -251,22 +257,38 @@ class VOCDetection(Dataset):
                         **
                         label_info
                     })
-                    ct += 1
                     annotations['images'].append({
                         'height': im_h,
                         'width': im_w,
                         'id': int(im_id[0]),
                         'file_name': osp.split(img_file)[1]
                     })
+                else:
+                    self.neg_file_list.append({
+                        'image': img_file,
+                        **
+                        im_info,
+                        **
+                        label_info
+                    })
+                ct += 1
+
                 if self.use_mix:
                     self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
                 else:
                     self.num_max_boxes = max(self.num_max_boxes, len(objs))
 
-        if not len(self.file_list) > 0:
-            raise Exception('not found any voc record in %s' % (file_list))
-        logging.info("{} samples in file {}".format(
-            len(self.file_list), file_list))
+        if not ct:
+            logging.error(
+                "No voc record found in %s" % file_list, exit=True)
+        logging.info(
+            "{} samples in file {}, including {} positive samples and {} negative samples.".
+            format(
+                len(self.file_list) + len(self.neg_file_list), file_list,
+                len(self.file_list), len(self.neg_file_list)))
+
+        if self.allow_empty:
+            self.file_list += self.neg_file_list
         self.num_samples = len(self.file_list)
         self.coco_gt = COCO()
         self.coco_gt.dataset = annotations