|
|
@@ -104,22 +104,28 @@ class BaseSegmenter(BaseModel):
|
|
|
outputs = OrderedDict()
|
|
|
if mode == 'test':
|
|
|
origin_shape = inputs[1]
|
|
|
- score_map = self._postprocess(
|
|
|
- logit, origin_shape, transforms=inputs[2])
|
|
|
- label_map = paddle.argmax(
|
|
|
- score_map, axis=1, keepdim=True, dtype='int32')
|
|
|
- score_map = paddle.transpose(
|
|
|
- paddle.nn.functional.softmax(
|
|
|
- score_map, axis=1),
|
|
|
- perm=[0, 2, 3, 1])
|
|
|
- score_map = paddle.squeeze(score_map)
|
|
|
- label_map = paddle.squeeze(label_map)
|
|
|
- outputs = {'label_map': label_map, 'score_map': score_map}
|
|
|
+ if self.status == 'Infer':
|
|
|
+ score_map, label_map = self._postprocess(
|
|
|
+ net_out, origin_shape, transforms=inputs[2])
|
|
|
+ else:
|
|
|
+ logit = self._postprocess(
|
|
|
+ logit, origin_shape, transforms=inputs[2])
|
|
|
+ score_map = paddle.transpose(
|
|
|
+ F.softmax(
|
|
|
+ logit, axis=1), perm=[0, 2, 3, 1])
|
|
|
+ label_map = paddle.argmax(
|
|
|
+ score_map, axis=-1, keepdim=True, dtype='int32')
|
|
|
+ outputs['label_map'] = paddle.squeeze(label_map)
|
|
|
+ outputs['score_map'] = paddle.squeeze(score_map)
|
|
|
+
|
|
|
if mode == 'eval':
|
|
|
- pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
|
|
|
+ if self.status == 'Infer':
|
|
|
+ pred = paddle.transpose(net_out[1], perm=[0, 3, 1, 2])
|
|
|
+ else:
|
|
|
+ pred = paddle.argmax(
|
|
|
+ logit, axis=1, keepdim=True, dtype='int32')
|
|
|
label = inputs[1]
|
|
|
origin_shape = [label.shape[-2:]]
|
|
|
- # TODO: 替换cv2后postprocess移出run
|
|
|
pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
|
|
|
pred, label, self.num_classes)
|
|
|
@@ -570,33 +576,72 @@ class BaseSegmenter(BaseModel):
|
|
|
def _postprocess(self, batch_pred, batch_origin_shape, transforms):
|
|
|
batch_restore_list = BaseSegmenter.get_transforms_shape_info(
|
|
|
batch_origin_shape, transforms)
|
|
|
- results = list()
|
|
|
+ if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
|
|
|
+ return self._infer_postprocess(
|
|
|
+ batch_score_map=batch_pred[0],
|
|
|
+ batch_label_map=batch_pred[1],
|
|
|
+ batch_restore_list=batch_restore_list)
|
|
|
+ results = []
|
|
|
for pred, restore_list in zip(batch_pred, batch_restore_list):
|
|
|
- if not isinstance(pred, np.ndarray):
|
|
|
- pred = paddle.unsqueeze(pred, axis=0)
|
|
|
+ pred = paddle.unsqueeze(pred, axis=0)
|
|
|
for item in restore_list[::-1]:
|
|
|
- # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
|
|
|
h, w = item[1][0], item[1][1]
|
|
|
if item[0] == 'resize':
|
|
|
- if not isinstance(pred, np.ndarray):
|
|
|
- pred = F.interpolate(pred, (h, w), mode='nearest')
|
|
|
+ pred = F.interpolate(
|
|
|
+ pred, (h, w), mode='nearest', data_format='NCHW')
|
|
|
+ elif item[0] == 'padding':
|
|
|
+ x, y = item[2]
|
|
|
+ pred = pred[:, :, y:y + h, x:x + w]
|
|
|
+ else:
|
|
|
+ pass
|
|
|
+ results.append(pred)
|
|
|
+ batch_pred = paddle.concat(results, axis=0)
|
|
|
+ return batch_pred
|
|
|
+
|
|
|
+ def _infer_postprocess(self, batch_score_map, batch_label_map,
|
|
|
+ batch_restore_list):
|
|
|
+ score_maps = []
|
|
|
+ label_maps = []
|
|
|
+ for score_map, label_map, restore_list in zip(
|
|
|
+ batch_score_map, batch_label_map, batch_restore_list):
|
|
|
+ if not isinstance(score_map, np.ndarray):
|
|
|
+ score_map = paddle.unsqueeze(score_map, axis=0)
|
|
|
+ label_map = paddle.unsqueeze(label_map, axis=0)
|
|
|
+ for item in restore_list[::-1]:
|
|
|
+ h, w = item[1][0], item[1][1]
|
|
|
+ if item[0] == 'resize':
|
|
|
+ if isinstance(score_map, np.ndarray):
|
|
|
+ score_map = cv2.resize(
|
|
|
+ score_map, (h, w), interpolation=cv2.INTER_LINEAR)
|
|
|
+ label_map = cv2.resize(
|
|
|
+ label_map, (h, w), interpolation=cv2.INTER_NEAREST)
|
|
|
else:
|
|
|
- pred = cv2.resize(
|
|
|
- pred, (h, w), interpolation=cv2.INTER_NEAREST)
|
|
|
+ score_map = F.interpolate(
|
|
|
+ score_map, (h, w),
|
|
|
+ mode='bilinear',
|
|
|
+ data_format='NHWC')
|
|
|
+ label_map = F.interpolate(
|
|
|
+ label_map, (h, w),
|
|
|
+ mode='nearest',
|
|
|
+ data_format='NHWC')
|
|
|
elif item[0] == 'padding':
|
|
|
x, y = item[2]
|
|
|
- if not isinstance(pred, np.ndarray):
|
|
|
- pred = pred[:, :, y:y + h, x:x + w]
|
|
|
+ if isinstance(score_map, np.ndarray):
|
|
|
+ score_map = score_map[..., y:y + h, x:x + w]
|
|
|
+ label_map = label_map[..., y:y + h, x:x + w]
|
|
|
else:
|
|
|
- pred = pred[..., y:y + h, x:x + w]
|
|
|
+ score_map = score_map[:, :, y:y + h, x:x + w]
|
|
|
+ label_map = label_map[:, :, y:y + h, x:x + w]
|
|
|
else:
|
|
|
pass
|
|
|
- results.append(pred)
|
|
|
- if not isinstance(pred, np.ndarray):
|
|
|
- batch_pred = paddle.concat(results, axis=0)
|
|
|
+ score_maps.append(score_map)
|
|
|
+ label_maps.append(label_map)
|
|
|
+ if isinstance(score_maps[0], np.ndarray):
|
|
|
+ return np.stack(score_maps, axis=0), np.stack(label_maps, axis=0)
|
|
|
else:
|
|
|
- batch_pred = np.stack(results, axis=0)
|
|
|
- return batch_pred
|
|
|
+ return paddle.concat(
|
|
|
+ score_maps, axis=0), paddle.concat(
|
|
|
+ label_maps, axis=0)
|
|
|
|
|
|
|
|
|
class UNet(BaseSegmenter):
|