Explorar o código

support add negative samples

will-jl944 hai 4 anos
pai
achega
e21af195de
Modificouse 1 ficheiro con 63 adicións e 1 borrado
  1. 63 1
      paddlex/cv/datasets/voc.py

+ 63 - 1
paddlex/cv/datasets/voc.py

@@ -14,6 +14,7 @@
 
 from __future__ import absolute_import
 import copy
+import os
 import os.path as osp
 import random
 import re
@@ -239,7 +240,7 @@ class VOCDetection(Dataset):
                 im_info = {
                     'im_id': im_id,
                     'image_shape': np.array(
-                        [im_h, im_w], dtype=np.int32),
+                        [im_h, im_w], dtype=np.int32)
                 }
                 label_info = {
                     'is_crowd': is_crowd,
@@ -321,3 +322,64 @@ class VOCDetection(Dataset):
 
     def set_epoch(self, epoch_id):
         self._epoch = epoch_id
+
+    def add_negative_samples(self, image_dir):
+        """Add background-only images to the dataset as negative samples.
+
+        Every picture in ``image_dir`` becomes a record with empty labels and
+        is appended to both ``self.file_list`` and ``self.neg_file_list``;
+        ``self.num_samples`` is refreshed afterwards.
+
+        Args:
+            image_dir (str): Directory that holds the background images.
+
+        Raises:
+            Exception: If ``image_dir`` is not an existing directory.
+        """
+        import cv2
+        if not osp.isdir(image_dir):
+            raise Exception("{} is not a valid image directory.".format(
+                image_dir))
+        # Mirror the optional 'gt_poly' key of existing records; guarding on
+        # emptiness avoids an IndexError when the dataset has no samples yet.
+        with_poly = bool(self.file_list) and 'gt_poly' in self.file_list[0]
+        max_img_id = len(self.file_list)
+        ct = 0
+        # Sorted so the assigned ids do not depend on the OS listing order.
+        for image in sorted(os.listdir(image_dir)):
+            if not is_pic(image):
+                continue
+            im_fname = osp.join(image_dir, image)
+            img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
+            if img_data is None:
+                # cv2.imread signals an unreadable file by returning None.
+                logging.warning("Skipping unreadable image {}.".format(
+                    im_fname))
+                continue
+            # Grayscale files decode to 2-D arrays, so take only h and w.
+            im_h, im_w = img_data.shape[:2]
+            max_img_id += 1
+
+            # All label arrays are empty: a negative sample has no objects.
+            sample = {
+                'image': im_fname,
+                'im_id': np.asarray([max_img_id]),
+                'image_shape': np.array(
+                    [im_h, im_w], dtype=np.int32),
+                'is_crowd': np.array([], dtype=np.int32),
+                'gt_class': np.array([], dtype=np.int32),
+                'gt_bbox': np.array([], dtype=np.float32),
+                'gt_score': np.array([], dtype=np.float32),
+                'difficult': np.array([], dtype=np.int32)
+            }
+            if with_poly:
+                sample['gt_poly'] = []
+
+            # Separate (shallow-copied) dicts per list, as before, so a later
+            # key update on one record does not leak into the other list.
+            self.neg_file_list.append(dict(sample))
+            self.file_list.append(dict(sample))
+            ct += 1
+
+        self.num_samples = len(self.file_list)
+        logging.info("{} negative samples added.".format(ct))