浏览代码

fix bug in generate_minibatch

FlyingQianMM 5 年之前
父节点
当前提交
2c4d02eaa0
共有 1 个文件被更改,包括 3 次插入1 次删除
  1. 3 1
      paddlex/cv/datasets/dataset.py

+ 3 - 1
paddlex/cv/datasets/dataset.py

@@ -218,9 +218,11 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 1:
-            if isinstance(data[1], np.ndarray):
+            if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation
                 # during the training  and evaluating phase
+                # the data[1] of segmentation is a image array,
+                # so len(data[1].shape) > 1
                 padding_label = np.zeros(
                     (1, max_shape[1], max_shape[2]
                      )).astype('int64') + label_padding_value