فهرست منبع

Merge pull request #213 from FlyingQianMM/develop_qh

fix bug in generate_minibatch
Jason 5 سال پیش
والد
کامیت
4c21da1f8c
1فایلهای تغییر یافته به همراه3 افزوده شده و 1 حذف شده
  1. 3 1
      paddlex/cv/datasets/dataset.py

+ 3 - 1
paddlex/cv/datasets/dataset.py

@@ -218,9 +218,11 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 1:
-            if isinstance(data[1], np.ndarray):
+            if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation
                 # during the training  and evaluating phase
+                # the data[1] of segmentation is a image array,
+                # so len(data[1].shape) > 1
                 padding_label = np.zeros(
                     (1, max_shape[1], max_shape[2]
                      )).astype('int64') + label_padding_value