浏览代码

Merge pull request #213 from FlyingQianMM/develop_qh

fix bug in generate_minibatch
Jason 5 年之前
父节点
当前提交
4c21da1f8c
共有 1 个文件被更改,包括 3 次插入1 次删除
  1. 3 1
      paddlex/cv/datasets/dataset.py

+ 3 - 1
paddlex/cv/datasets/dataset.py

@@ -218,9 +218,11 @@ def generate_minibatch(batch_data, label_padding_value=255):
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
         padding_im[:, :im_h, :im_w] = data[0]
         if len(data) > 1:
         if len(data) > 1:
-            if isinstance(data[1], np.ndarray):
+            if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                 # padding the image and label of segmentation
                 # padding the image and label of segmentation
                 # during the training  and evaluating phase
                 # during the training  and evaluating phase
+                # the data[1] of segmentation is a image array,
+                # so len(data[1].shape) > 1
                 padding_label = np.zeros(
                 padding_label = np.zeros(
                     (1, max_shape[1], max_shape[2]
                     (1, max_shape[1], max_shape[2]
                      )).astype('int64') + label_padding_value
                      )).astype('int64') + label_padding_value