Przeglądaj źródła

skip pad_gt in batch_padding

will-jl944 4 lat temu
rodzic
commit
b4befaa271
1 zmienionych plików z 1 dodań i 65 usunięć
  1. 1 65
      dygraph/paddlex/cv/transforms/batch_operators.py

+ 1 - 65
dygraph/paddlex/cv/transforms/batch_operators.py

@@ -149,10 +149,9 @@ class BatchRandomResizeByShort(Transform):
 
 
 class _BatchPadding(Transform):
-    def __init__(self, pad_to_stride=0, pad_gt=False):
+    def __init__(self, pad_to_stride=0):
         super(_BatchPadding, self).__init__()
         self.pad_to_stride = pad_to_stride
-        self.pad_gt = pad_gt
 
     def __call__(self, samples):
         coarsest_stride = self.pad_to_stride
@@ -171,69 +170,6 @@ class _BatchPadding(Transform):
             padding_im[:im_h, :im_w, :] = im
             data['image'] = padding_im
 
-        if self.pad_gt:
-            gt_num = []
-            if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
-                    'gt_poly']) > 0:
-                pad_mask = True
-            else:
-                pad_mask = False
-
-            if pad_mask:
-                poly_num = []
-                poly_part_num = []
-                point_num = []
-
-            for data in samples:
-                gt_num.append(data['gt_bbox'].shape[0])
-                if pad_mask:
-                    poly_num.append(len(data['gt_poly']))
-                    for poly in data['gt_poly']:
-                        poly_part_num.append(int(len(poly)))
-                        for p_p in poly:
-                            point_num.append(int(len(p_p) / 2))
-            gt_num_max = max(gt_num)
-
-            for i, data in enumerate(samples):
-                gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
-                gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
-                is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
-
-                if pad_mask:
-                    poly_num_max = max(poly_num)
-                    poly_part_num_max = max(poly_part_num)
-                    point_num_max = max(point_num)
-                    gt_masks_data = -np.ones(
-                        [poly_num_max, poly_part_num_max, point_num_max, 2],
-                        dtype=np.float32)
-
-                gt_num = data['gt_bbox'].shape[0]
-                gt_box_data[0:gt_num, :] = data['gt_bbox']
-                gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
-                if 'is_crowd' in data:
-                    is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
-                    data['is_crowd'] = is_crowd_data
-
-                data['gt_bbox'] = gt_box_data
-                data['gt_class'] = gt_class_data
-
-                if pad_mask:
-                    for j, poly in enumerate(data['gt_poly']):
-                        for k, p_p in enumerate(poly):
-                            pp_np = np.array(p_p).reshape(-1, 2)
-                            gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
-                    data['gt_poly'] = gt_masks_data
-
-                if 'gt_score' in data:
-                    gt_score_data = np.zeros([gt_num_max], dtype=np.float32)
-                    gt_score_data[0:gt_num] = data['gt_score'][:gt_num, 0]
-                    data['gt_score'] = gt_score_data
-
-                if 'difficult' in data:
-                    diff_data = np.zeros([gt_num_max], dtype=np.int32)
-                    diff_data[0:gt_num] = data['difficult'][:gt_num, 0]
-                    data['difficult'] = diff_data
-
         return samples