Forráskód Böngészése

Merge pull request #249 from FlyingQianMM/develop_qh_3

fix bug in segmentation evaluation
Jason 5 éve
szülő
commit
7e91288575
1 módosított fájl, 5 hozzáadás és 4 törlés
  1. 5 4
      paddlex/cv/models/deeplabv3p.py

+ 5 - 4
paddlex/cv/models/deeplabv3p.py

@@ -360,18 +360,19 @@ class DeepLabv3p(BaseAPI):
                 pred = pred[0:num_samples]
 
             for i in range(num_samples):
-                one_pred = pred[i].astype('uint8')
+                one_pred = np.squeeze(pred[i]).astype('uint8')
                 one_label = labels[i]
                 for info in im_info[i][::-1]:
                     if info[0] == 'resize':
                         w, h = info[1][1], info[1][0]
-                        one_pred = cv2.resize(one_pred, (w, h), cv2.INTER_NEAREST)
+                        one_pred = cv2.resize(one_pred, (w, h),
+                                              cv2.INTER_NEAREST)
                     elif info[0] == 'padding':
                         w, h = info[1][1], info[1][0]
                         one_pred = one_pred[0:h, 0:w]
                     else:
-                        raise Exception("Unexpected info '{}' in im_info".format(
-                            info[0]))
+                        raise Exception(
+                            "Unexpected info '{}' in im_info".format(info[0]))
                 one_pred = one_pred.astype('int64')
                 one_pred = one_pred[np.newaxis, :, :, np.newaxis]
                 one_label = one_label[np.newaxis, np.newaxis, :, :]