sunyanfang01 преди 5 години
родител
ревизия
2ec3e44938
променени са 4 файла, в които са добавени 33 реда и са изтрити 43 реда
  1. 1 1
      paddlex/cv/datasets/coco.py
  2. 1 1
      paddlex/cv/datasets/easydata_det.py
  3. 1 1
      paddlex/cv/datasets/voc.py
  4. 30 40
      paddlex/cv/transforms/det_transforms.py

+ 1 - 1
paddlex/cv/datasets/coco.py

@@ -112,7 +112,7 @@ class CocoDetection(VOCDetection):
 
             im_info = {
                 'im_id': np.array([img_id]).astype('int32'),
-                'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                'image_shape': np.array([im_h, im_w]).astype('int32'),
             }
             label_info = {
                 'is_crowd': is_crowd,

+ 1 - 1
paddlex/cv/datasets/easydata_det.py

@@ -143,7 +143,7 @@ class EasyDataDet(VOCDetection):
                     ann_ct += 1
                 im_info = {
                     'im_id': im_id,
-                    'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array([im_h, im_w]).astype('int32'),
                 }
                 label_info = {
                     'is_crowd': is_crowd,

+ 1 - 1
paddlex/cv/datasets/voc.py

@@ -146,7 +146,7 @@ class VOCDetection(Dataset):
 
                 im_info = {
                     'im_id': im_id,
-                    'origin_shape': np.array([im_h, im_w]).astype('int32'),
+                    'image_shape': np.array([im_h, im_w]).astype('int32'),
                 }
                 label_info = {
                     'is_crowd': is_crowd,

+ 30 - 40
paddlex/cv/transforms/det_transforms.py

@@ -58,8 +58,8 @@ class Compose:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
             im_info (dict): 存储与图像相关的信息,dict中的字段如下:
                 - im_id (np.ndarray): 图像序列号,形状为(1,)。
-                - origin_shape (np.ndarray): 图像原始大小,形状为(2,),
-                                        origin_shape[0]为高,origin_shape[1]为宽。
+                - image_shape (np.ndarray): 图像原始大小,形状为(2,),
+                                        image_shape[0]为高,image_shape[1]为宽。
                 - mixup (list): list为[im, im_info, label_info],分别对应
                                 与当前图像进行mixup的图像np.ndarray数据、图像相关信息、标注框相关信息;
                                 注意,当前epoch若无需进行mixup,则无该字段。
@@ -93,9 +93,6 @@ class Compose:
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
                 [im.shape[0], im.shape[1], 1.], dtype=np.float32)
-            # copy augment_shape from origin_shape
-            im_info['augment_shape'] = np.array([im.shape[0],
-                                                 im.shape[1]]).astype('int32')
             if not self.use_mixup:
                 if 'mixup' in im_info:
                     del im_info['mixup']
@@ -387,16 +384,13 @@ class RandomHorizontalFlip:
             raise TypeError(
                 'Cannot do RandomHorizontalFlip! ' +
                 'Becasuse the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomHorizontalFlip! ' + \
-                            'Becasuse augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info:
             raise TypeError('Cannot do RandomHorizontalFlip! ' + \
                             'Becasuse gt_bbox is not in label_info!')
-        augment_shape = im_info['augment_shape']
+        image_shape = im_info['image_shape']
         gt_bbox = label_info['gt_bbox']
-        height = augment_shape[0]
-        width = augment_shape[1]
+        height = image_shape[0]
+        width = image_shape[1]
 
         if np.random.uniform(0, 1) < self.prob:
             im = horizontal_flip(im)
@@ -567,7 +561,7 @@ class MixupImage:
             (2)拼接原图像标注框和mixup图像标注框。
             (3)拼接原图像标注框类别和mixup图像标注框类别。
             (4)原图像标注框混合得分乘以factor,mixup图像标注框混合得分乘以(1-factor),叠加2个结果。
-    3. 更新im_info中的augment_shape信息。
+    3. 更新im_info中的image_shape信息。
 
     Args:
         alpha (float): 随机beta分布的下限。默认为1.5。
@@ -610,7 +604,7 @@ class MixupImage:
                    当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
                    存储与标注框相关信息的字典。
                    其中,im_info更新字段为:
-                       - augment_shape (np.ndarray): mixup后的图像高、宽二者组成的np.ndarray,形状为(2,)。
+                       - image_shape (np.ndarray): mixup后的图像高、宽二者组成的np.ndarray,形状为(2,)。
                    im_info删除的字段:
                        - mixup (list): 与当前字段进行mixup的图像相关信息。
                    label_info更新字段为:
@@ -674,8 +668,8 @@ class MixupImage:
         label_info['gt_score'] = gt_score
         label_info['gt_class'] = gt_class
         label_info['is_crowd'] = is_crowd
-        im_info['augment_shape'] = np.array([im.shape[0],
-                                             im.shape[1]]).astype('int32')
+        im_info['image_shape'] = np.array([im.shape[0],
+                                           im.shape[1]]).astype('int32')
         im_info.pop('mixup')
         if label_info is None:
             return (im, im_info)
@@ -721,7 +715,7 @@ class RandomExpand:
                    当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
                    存储与标注框相关信息的字典。
                    其中,im_info更新字段为:
-                       - augment_shape (np.ndarray): 扩张后的图像高、宽二者组成的np.ndarray,形状为(2,)。
+                       - image_shape (np.ndarray): 扩张后的图像高、宽二者组成的np.ndarray,形状为(2,)。
                    label_info更新字段为:
                        - gt_bbox (np.ndarray): 随机扩张后真实标注框坐标,形状为(n, 4),
                                           其中n代表真实标注框的个数。
@@ -734,9 +728,6 @@ class RandomExpand:
             raise TypeError(
                 'Cannot do RandomExpand! ' +
                 'Becasuse the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomExpand! ' + \
-                            'Becasuse augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info or \
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomExpand! ' + \
@@ -744,9 +735,9 @@ class RandomExpand:
         if np.random.uniform(0., 1.) < self.prob:
             return (im, im_info, label_info)
 
-        augment_shape = im_info['augment_shape']
-        height = int(augment_shape[0])
-        width = int(augment_shape[1])
+        image_shape = im_info['image_shape']
+        height = int(image_shape[0])
+        width = int(image_shape[1])
 
         expand_ratio = np.random.uniform(1., self.ratio)
         h = int(height * expand_ratio)
@@ -759,7 +750,7 @@ class RandomExpand:
         canvas *= np.array(self.fill_value, dtype=np.float32)
         canvas[y:y + height, x:x + width, :] = im
 
-        im_info['augment_shape'] = np.array([h, w]).astype('int32')
+        im_info['image_shape'] = np.array([h, w]).astype('int32')
         if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
             label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
         if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
@@ -815,12 +806,14 @@ class RandomCrop:
             tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
                    当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
                    存储与标注框相关信息的字典。
-                   其中,label_info更新字段为:
-                       - gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
+                   其中,im_info更新字段为:
+                           - image_shape (np.ndarray): 扩裁剪的图像高、宽二者组成的np.ndarray,形状为(2,)。
+                       label_info更新字段为:
+                           - gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
                                           其中n代表真实标注框的个数。
-                       - gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
+                           - gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
                                            其中n代表真实标注框的个数。
-                       - gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
+                           - gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
                                            其中n代表真实标注框的个数。
 
         Raises:
@@ -830,9 +823,6 @@ class RandomCrop:
             raise TypeError(
                 'Cannot do RandomCrop! ' +
                 'Becasuse the im_info and label_info can not be None!')
-        if 'augment_shape' not in im_info:
-            raise TypeError('Cannot do RandomCrop! ' + \
-                            'Becasuse augment_shape is not in im_info!')
         if 'gt_bbox' not in label_info or \
                 'gt_class' not in label_info:
             raise TypeError('Cannot do RandomCrop! ' + \
@@ -841,9 +831,9 @@ class RandomCrop:
         if len(label_info['gt_bbox']) == 0:
             return (im, im_info, label_info)
 
-        augment_shape = im_info['augment_shape']
-        w = augment_shape[1]
-        h = augment_shape[0]
+        image_shape = im_info['image_shape']
+        w = image_shape[1]
+        h = image_shape[0]
         gt_bbox = label_info['gt_bbox']
         thresholds = list(self.thresholds)
         if self.allow_no_crop:
@@ -902,7 +892,7 @@ class RandomCrop:
                 label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
                 label_info['gt_class'] = np.take(
                     label_info['gt_class'], valid_ids, axis=0)
-                im_info['augment_shape'] = np.array(
+                im_info['image_shape'] = np.array(
                     [crop_box[3] - crop_box[1],
                      crop_box[2] - crop_box[0]]).astype('int32')
                 if 'gt_score' in label_info:
@@ -973,7 +963,7 @@ class ArrangeFasterRCNN:
             im_resize_info = im_info['im_resize_info']
             im_id = im_info['im_id']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             gt_bbox = label_info['gt_bbox']
             gt_class = label_info['gt_class']
@@ -986,7 +976,7 @@ class ArrangeFasterRCNN:
                                 'Becasuse the im_info can not be None!')
             im_resize_info = im_info['im_resize_info']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             outputs = (im, im_resize_info, im_shape)
         return outputs
@@ -1066,7 +1056,7 @@ class ArrangeMaskRCNN:
                                 'Becasuse the im_info can not be None!')
             im_resize_info = im_info['im_resize_info']
             im_shape = np.array(
-                (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
+                (im_info['image_shape'][0], im_info['image_shape'][1], 1),
                 dtype=np.float32)
             if self.mode == 'eval':
                 im_id = im_info['im_id']
@@ -1117,7 +1107,7 @@ class ArrangeYOLOv3:
                 raise TypeError(
                     'Cannot do ArrangeYolov3! ' +
                     'Becasuse the im_info and label_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             if len(label_info['gt_bbox']) != len(label_info['gt_class']):
                 raise ValueError("gt num mismatch: bbox and class.")
             if len(label_info['gt_bbox']) != len(label_info['gt_score']):
@@ -1141,7 +1131,7 @@ class ArrangeYOLOv3:
                 raise TypeError(
                     'Cannot do ArrangeYolov3! ' +
                     'Becasuse the im_info and label_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             if len(label_info['gt_bbox']) != len(label_info['gt_class']):
                 raise ValueError("gt num mismatch: bbox and class.")
             im_id = im_info['im_id']
@@ -1160,6 +1150,6 @@ class ArrangeYOLOv3:
             if im_info is None:
                 raise TypeError('Cannot do ArrangeYolov3! ' +
                                 'Becasuse the im_info can not be None!')
-            im_shape = im_info['augment_shape']
+            im_shape = im_info['image_shape']
             outputs = (im, im_shape)
         return outputs