Browse Source

Merge pull request #227 from wuyefeilin/develop

eval in origin image
Jason 5 years ago
parent
commit
cd856fb1b4

+ 11 - 3
paddlex/cv/datasets/dataset.py

@@ -217,10 +217,18 @@ def generate_minibatch(batch_data, label_padding_value=255):
         padding_im = np.zeros(
             (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
         padding_im[:, :im_h, :im_w] = data[0]
-        if len(data) > 1:
+        if len(data) > 2:
+            # Pad the image and label, and insert 'padding' into `im_info` for segmentation during the evaluation phase.
+            if len(data[1]) == 0 or 'padding' not in [
+                    data[1][i][0] for i in range(len(data[1]))
+            ]:
+                data[1].append(('padding', [im_h, im_w]))
+            padding_batch.append((padding_im, data[1], data[2]))
+
+            
+        elif len(data) > 1:
             if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
-                # padding the image and label of segmentation
-                # during the training  and evaluating phase
+                # padding the image and label of segmentation during the training
                 # the data[1] of segmentation is an image array,
                 # so len(data[1].shape) > 1
                 padding_label = np.zeros(

+ 20 - 4
paddlex/cv/models/deeplabv3p.py

@@ -340,7 +340,8 @@ class DeepLabv3p(BaseAPI):
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data])
-            labels = np.array([d[1] for d in data])
+            im_info = [d[1] for d in data]
+            labels = [d[2] for d in data]
 
             num_samples = images.shape[0]
             if num_samples < batch_size:
@@ -358,10 +359,25 @@ class DeepLabv3p(BaseAPI):
             if num_samples < batch_size:
                 pred = pred[0:num_samples]
 
-            mask = labels != self.ignore_index
-            conf_mat.calculate(pred=pred, label=labels, ignore=mask)
+            for i in range(num_samples):
+                one_pred = pred[i].astype('uint8')
+                one_label = labels[i]
+                for info in im_info[i][::-1]:
+                    if info[0] == 'resize':
+                        w, h = info[1][1], info[1][0]
+                        one_pred = cv2.resize(one_pred, (w, h), cv2.INTER_NEAREST)
+                    elif info[0] == 'padding':
+                        w, h = info[1][1], info[1][0]
+                        one_pred = one_pred[0:h, 0:w]
+                    else:
+                        raise Exception("Unexpected info '{}' in im_info".format(
+                            info[0]))
+                one_pred = one_pred.astype('int64')
+                one_pred = one_pred[np.newaxis, :, :, np.newaxis]
+                one_label = one_label[np.newaxis, np.newaxis, :, :]
+                mask = one_label != self.ignore_index
+                conf_mat.calculate(pred=one_pred, label=one_label, ignore=mask)
             _, iou = conf_mat.mean_iou()
-
             logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
                 epoch_id, step + 1, total_steps, iou))
 

+ 9 - 1
paddlex/cv/transforms/seg_transforms.py

@@ -90,6 +90,7 @@ class Compose(SegTransform):
         if label is not None:
             if not isinstance(label, np.ndarray):
                 label = np.asarray(Image.open(label))
+            origin_label = label.copy()
         for op in self.transforms:
             if isinstance(op, SegTransform):
                 outputs = op(im, im_info, label)
@@ -104,6 +105,10 @@ class Compose(SegTransform):
                     outputs = (im, im_info, label)
                 else:
                     outputs = (im, im_info)
+        if self.transforms[-1].__class__.__name__ == 'ArrangeSegmenter':
+            if self.transforms[-1].mode == 'eval':
+                if label is not None:
+                    outputs = (im, im_info, origin_label)
         return outputs
 
     def add_augmenters(self, augmenters):
@@ -1092,9 +1097,12 @@ class ArrangeSegmenter(SegTransform):
                 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
         """
         im = permute(im, False)
-        if self.mode == 'train' or self.mode == 'eval':
+        if self.mode == 'train':
             label = label[np.newaxis, :, :]
             return (im, label)
+        if self.mode == 'eval':
+            label = label[np.newaxis, :, :]
+            return (im, im_info, label)
         elif self.mode == 'test':
             return (im, im_info)
         else: