Parcourir la source

Merge pull request #366 from FlyingQianMM/develop_new

fix generate_mini_batch bugs
Jason il y a 5 ans
Parent
commit
9f13df142a
1 fichiers modifiés avec 19 ajouts et 25 suppressions
  1. 19 25
      paddlex/cv/datasets/dataset.py

+ 19 - 25
paddlex/cv/datasets/dataset.py

@@ -220,42 +220,36 @@ def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
         padding_im = np.zeros(
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
-        if len(data) > 2:
-            # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
-            if len(data[1]) == 0 or 'padding' not in [
-                    data[1][i][0] for i in range(len(data[1]))
-            ]:
-                data[1].append(('padding', [im_h, im_w]))
-            padding_batch.append((padding_im, data[1], data[2]))
-
-        elif len(data) > 1:
-            if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
-                # padding the image and label of segmentation during the training
-                # the data[1] of segmentation is a image array,
-                # so len(data[1].shape) > 1
-                padding_label = np.zeros(
-                    (1, max_shape[1], max_shape[2]
-                     )).astype('int64') + label_padding_value
-                _, label_h, label_w = data[1].shape
-                padding_label[:, :label_h, :label_w] = data[1]
-                padding_batch.append((padding_im, padding_label))
-            elif len(data[1]) == 0 or isinstance(data[1][0], tuple) and data[
-                    1][0][0] in ['origin_shape', 'resize', 'padding']:
-                # padding the image and insert 'padding' into `im_info`
-                # of segmentation during the infering phase
+        if len(data) > 1:
+            if isinstance(data[1], np.ndarray):
+                if data[1].ndim == 3:
+                    # padding the image and label of segmentation during the training
+                    # the data[1] of segmentation is a image array, so data[1].ndim is 3.
+                    padding_label = np.zeros(
+                        (1, max_shape[1], max_shape[2]
+                         )).astype('int64') + label_padding_value
+                    _, label_h, label_w = data[1].shape
+                    padding_label[:, :label_h, :label_w] = data[1]
+                    padding_batch.append((padding_im, padding_label))
+                else:
+                    # padding the image of detection
+                    padding_batch.append((padding_im, ) + tuple(data[1:]))
+            elif isinstance(data[1], (list, tuple)):
+                # padding the image and insert 'padding' into `im_info` of segmentation
+                # during evaluating or inferring phase.
                 if len(data[1]) == 0 or 'padding' not in [
                         data[1][i][0] for i in range(len(data[1]))
                 ]:
                     data[1].append(('padding', [im_h, im_w]))
                 padding_batch.append((padding_im, ) + tuple(data[1:]))
             else:
-                # padding the image of detection, or
+                # padding the image of classification during the train/eval  phase
                 # padding the image of classification during the trainging
                 # and evaluating phase
                 padding_batch.append((padding_im, ) + tuple(data[1:]))
         else:
             # padding the image of classification during the infering phase
-            padding_batch.append((padding_im))
+            padding_batch.append((padding_im, ))
     return padding_batch