|
|
@@ -14,6 +14,7 @@
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
import copy
|
|
|
+import os
|
|
|
import os.path as osp
|
|
|
import random
|
|
|
import re
|
|
|
@@ -239,7 +240,7 @@ class VOCDetection(Dataset):
|
|
|
im_info = {
|
|
|
'im_id': im_id,
|
|
|
'image_shape': np.array(
|
|
|
- [im_h, im_w], dtype=np.int32),
|
|
|
+ [im_h, im_w], dtype=np.int32)
|
|
|
}
|
|
|
label_info = {
|
|
|
'is_crowd': is_crowd,
|
|
|
@@ -321,3 +322,64 @@ class VOCDetection(Dataset):
|
|
|
|
|
|
def set_epoch(self, epoch_id):
|
|
|
self._epoch = epoch_id
|
|
|
+
|
|
|
+ def add_negative_samples(self, image_dir):
|
|
|
+ """将背景图片加入训练
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image_dir (str):背景图片所在的文件夹目录。
|
|
|
+
|
|
|
+ """
|
|
|
+ import cv2
|
|
|
+ if not osp.isdir(image_dir):
|
|
|
+ raise Exception("{} is not a valid image directory.".format(
|
|
|
+ image_dir))
|
|
|
+ image_list = os.listdir(image_dir)
|
|
|
+ max_img_id = len(self.file_list)
|
|
|
+ ct = 0
|
|
|
+ for image in image_list:
|
|
|
+ if not is_pic(image):
|
|
|
+ continue
|
|
|
+ gt_bbox = np.array([], dtype=np.float32)
|
|
|
+ gt_class = np.array([], dtype=np.int32)
|
|
|
+ gt_score = np.array([], dtype=np.float32)
|
|
|
+ is_crowd = np.array([], dtype=np.int32)
|
|
|
+ difficult = np.array([], dtype=np.int32)
|
|
|
+
|
|
|
+ max_img_id += 1
|
|
|
+ im_fname = osp.join(image_dir, image)
|
|
|
+ img_data = cv2.imread(im_fname, cv2.IMREAD_UNCHANGED)
|
|
|
+ im_h, im_w, im_c = img_data.shape
|
|
|
+
|
|
|
+ im_info = {
|
|
|
+ 'im_id': np.asarray([max_img_id]),
|
|
|
+ 'image_shape': np.array(
|
|
|
+ [im_h, im_w], dtype=np.int32)
|
|
|
+ }
|
|
|
+ label_info = {
|
|
|
+ 'is_crowd': is_crowd,
|
|
|
+ 'gt_class': gt_class,
|
|
|
+ 'gt_bbox': gt_bbox,
|
|
|
+ 'gt_score': gt_score,
|
|
|
+ 'difficult': difficult
|
|
|
+ }
|
|
|
+ if 'gt_poly' in self.file_list[0]:
|
|
|
+ label_info['gt_poly'] = []
|
|
|
+
|
|
|
+ self.neg_file_list.append({
|
|
|
+ 'image': im_fname,
|
|
|
+ **
|
|
|
+ im_info,
|
|
|
+ **
|
|
|
+ label_info
|
|
|
+ })
|
|
|
+ self.file_list.append({
|
|
|
+ 'image': im_fname,
|
|
|
+ **
|
|
|
+ im_info,
|
|
|
+ **
|
|
|
+ label_info
|
|
|
+ })
|
|
|
+ ct += 1
|
|
|
+ self.num_samples = len(self.file_list)
|
|
|
+ logging.info("{} negative samples added.".format(ct))
|