Forráskód Böngészése

fix miou is 0 when im_info is []

FlyingQianMM 5 éve
szülő
commit
2c7e59d4ff

+ 0 - 6
paddlex/cv/models/deeplabv3p.py

@@ -420,9 +420,6 @@ class DeepLabv3p(BaseAPI):
                     elif info[0] == 'padding':
                         w, h = info[1][1], info[1][0]
                         one_pred = one_pred[0:h, 0:w]
-                    else:
-                        raise Exception(
-                            "Unexpected info '{}' in im_info".format(info[0]))
                 one_pred = one_pred.astype('int64')
                 one_pred = one_pred[np.newaxis, :, :, np.newaxis]
                 one_label = one_label[np.newaxis, np.newaxis, :, :]
@@ -480,9 +477,6 @@ class DeepLabv3p(BaseAPI):
                     w, h = info[1][1], info[1][0]
                     pred = pred[0:h, 0:w]
                     logit = logit[0:h, 0:w, :]
-                else:
-                    raise Exception("Unexpected info '{}' in im_info".format(
-                        info[0]))
             pred_list.append(pred)
             logit_list.append(logit)
 

+ 2 - 2
paddlex/cv/transforms/seg_transforms.py

@@ -73,8 +73,6 @@ class Compose(SegTransform):
             tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
         """
 
-        if im_info is None:
-            im_info = list()
         if isinstance(im, np.ndarray):
             if len(im.shape) != 3:
                 raise Exception(
@@ -86,6 +84,8 @@ class Compose(SegTransform):
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(im))
         im = im.astype('float32')
+        if im_info is None:
+            im_info = [('origin_shape', im.shape[0:2])]
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if label is not None: