Browse Source

refine seg postprocess python code

will-jl944 4 years ago
parent
commit
bd9307c6d9
1 changed file with 46 additions and 32 deletions
  1. 46 32
      paddlex/cv/models/segmenter.py

+ 46 - 32
paddlex/cv/models/segmenter.py

@@ -105,28 +105,35 @@ class BaseSegmenter(BaseModel):
         if mode == 'test':
             origin_shape = inputs[1]
             if self.status == 'Infer':
-                label_map, score_map = self._postprocess(
+                label_map_list, score_map_list = self._postprocess(
                     net_out, origin_shape, transforms=inputs[2])
             else:
-                logit = self._postprocess(
+                logit_list = self._postprocess(
                     logit, origin_shape, transforms=inputs[2])
-                score_map = paddle.transpose(
-                    F.softmax(
-                        logit, axis=1), perm=[0, 2, 3, 1])
-                label_map = paddle.argmax(
-                    score_map, axis=-1, keepdim=True, dtype='int32')
-            outputs['label_map'] = paddle.squeeze(label_map)
-            outputs['score_map'] = paddle.squeeze(score_map)
+                label_map_list = []
+                score_map_list = []
+                for logit in logit_list:
+                    logit = paddle.transpose(logit, perm=[0, 2, 3, 1])  # NHWC
+                    label_map_list.append((paddle.argmax(
+                        logit, axis=-1, keepdim=False, dtype='int32')).squeeze(
+                        ).numpy().astype('int32'))
+                    score_map_list.append(
+                        F.softmax(
+                            logit, axis=-1).squeeze().numpy().astype(
+                                'float32'))
+            outputs['label_map'] = label_map_list
+            outputs['score_map'] = score_map_list
 
         if mode == 'eval':
             if self.status == 'Infer':
-                pred = paddle.transpose(net_out[1], perm=[0, 3, 1, 2])
+                pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
                 pred = paddle.argmax(
                     logit, axis=1, keepdim=True, dtype='int32')
             label = inputs[1]
             origin_shape = [label.shape[-2:]]
-            pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
+            pred = self._postprocess(
+                pred, origin_shape, transforms=inputs[2])[0]  # NCHW
             intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
                 pred, label, self.num_classes)
             outputs['intersect_area'] = intersect_area
@@ -477,8 +484,8 @@ class BaseSegmenter(BaseModel):
             If img_file is a string or np.array, the result is a dict with key-value pairs:
             {"label_map": `label map`, "score_map": `score map`}.
             If img_file is a list, the result is a list composed of dicts with the corresponding fields:
-            label_map(np.ndarray): the predicted label map
-            score_map(np.ndarray): the prediction score map (NHWC)
+            label_map(np.ndarray): the predicted label map (HW)
+            score_map(np.ndarray): the prediction score map (HWC)
 
         """
         if transforms is None and not hasattr(self, 'test_transforms'):
@@ -494,19 +501,23 @@ class BaseSegmenter(BaseModel):
         self.net.eval()
         data = (batch_im, batch_origin_shape, transforms.transforms)
         outputs = self.run(self.net, data, 'test')
-        label_map = outputs['label_map']
-        label_map = label_map.numpy().astype('uint8')
-        score_map = outputs['score_map']
-        score_map = score_map.numpy().astype('float32')
+        label_map_list = outputs['label_map']
+        score_map_list = outputs['score_map']
         if isinstance(img_file, list) and len(img_file) > 1:
             prediction = [{
                 'label_map': l,
                 'score_map': s
-            } for l, s in zip(label_map, score_map)]
+            } for l, s in zip(label_map_list, score_map_list)]
         elif isinstance(img_file, list):
-            prediction = [{'label_map': label_map, 'score_map': score_map}]
+            prediction = [{
+                'label_map': label_map_list[0],
+                'score_map': score_map_list[0]
+            }]
         else:
-            prediction = {'label_map': label_map, 'score_map': score_map}
+            prediction = {
+                'label_map': label_map_list[0],
+                'score_map': score_map_list[0]
+            }
         return prediction
 
     def _preprocess(self, images, transforms, to_tensor=True):
@@ -586,21 +597,24 @@ class BaseSegmenter(BaseModel):
                 batch_score_map=batch_pred[1],
                 batch_restore_list=batch_restore_list)
         results = []
+        if batch_pred.dtype == paddle.float32:
+            mode = 'bilinear'
+        else:
+            mode = 'nearest'
         for pred, restore_list in zip(batch_pred, batch_restore_list):
             pred = paddle.unsqueeze(pred, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
                 if item[0] == 'resize':
                     pred = F.interpolate(
-                        pred, (h, w), mode='nearest', data_format='NCHW')
+                        pred, (h, w), mode=mode, data_format='NCHW')
                 elif item[0] == 'padding':
                     x, y = item[2]
                     pred = pred[:, :, y:y + h, x:x + w]
                 else:
                     pass
             results.append(pred)
-        batch_pred = paddle.concat(results, axis=0)
-        return batch_pred
+        return results
 
     def _infer_postprocess(self, batch_label_map, batch_score_map,
                            batch_restore_list):
@@ -609,7 +623,7 @@ class BaseSegmenter(BaseModel):
         for label_map, score_map, restore_list in zip(
                 batch_label_map, batch_score_map, batch_restore_list):
             if not isinstance(label_map, np.ndarray):
-                label_map = paddle.unsqueeze(label_map, axis=0)
+                label_map = paddle.unsqueeze(label_map, axis=[0, 3])
                 score_map = paddle.unsqueeze(score_map, axis=0)
             for item in restore_list[::-1]:
                 h, w = item[1][0], item[1][1]
@@ -638,14 +652,14 @@ class BaseSegmenter(BaseModel):
                         score_map = score_map[:, :, y:y + h, x:x + w]
                 else:
                     pass
-            label_maps.append(label_map)
-            score_maps.append(score_map)
-        if isinstance(label_maps[0], np.ndarray):
-            return np.stack(label_maps, axis=0), np.stack(score_maps, axis=0)
-        else:
-            return paddle.concat(
-                label_maps, axis=0), paddle.concat(
-                    score_maps, axis=0)
+            label_map = label_map.squeeze()
+            score_map = score_map.squeeze()
+            if not isinstance(label_map, np.ndarray):
+                label_map = label_map.numpy()
+                score_map = score_map.numpy()
+            label_maps.append(label_map.squeeze())
+            score_maps.append(score_map.squeeze())
+        return label_maps, score_maps
 
 
 class UNet(BaseSegmenter):