|
|
@@ -15,6 +15,7 @@
|
|
|
import math
|
|
|
import os.path as osp
|
|
|
import numpy as np
|
|
|
+import cv2
|
|
|
from collections import OrderedDict
|
|
|
import paddle
|
|
|
import paddle.nn.functional as F
|
|
|
@@ -107,7 +108,10 @@ class BaseSegmenter(BaseModel):
|
|
|
logit, origin_shape, transforms=inputs[2])
|
|
|
label_map = paddle.argmax(
|
|
|
score_map, axis=1, keepdim=True, dtype='int32')
|
|
|
- score_map = paddle.nn.functional.softmax(score_map, axis=1)
|
|
|
+ score_map = paddle.transpose(
|
|
|
+ paddle.nn.functional.softmax(
|
|
|
+ score_map, axis=1),
|
|
|
+ perm=[0, 2, 3, 1])
|
|
|
score_map = paddle.squeeze(score_map)
|
|
|
label_map = paddle.squeeze(label_map)
|
|
|
outputs = {'label_map': label_map, 'score_map': score_map}
|
|
|
@@ -464,7 +468,7 @@ class BaseSegmenter(BaseModel):
|
|
|
{"label map": `label map`, "score_map": `score map`}.
|
|
|
If img_file is a list, the result is a list composed of dicts with the corresponding fields:
|
|
|
label_map(np.ndarray): the predicted label map
|
|
|
- score_map(np.ndarray): the prediction score map
|
|
|
+ score_map(np.ndarray): the prediction score map (NHWC)
|
|
|
|
|
|
"""
|
|
|
if transforms is None and not hasattr(self, 'test_transforms'):
|
|
|
@@ -568,19 +572,30 @@ class BaseSegmenter(BaseModel):
|
|
|
batch_origin_shape, transforms)
|
|
|
results = list()
|
|
|
for pred, restore_list in zip(batch_pred, batch_restore_list):
|
|
|
- pred = paddle.unsqueeze(pred, axis=0)
|
|
|
+ if not isinstance(pred, np.ndarray):
|
|
|
+ pred = paddle.unsqueeze(pred, axis=0)
|
|
|
for item in restore_list[::-1]:
|
|
|
# TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
|
|
|
h, w = item[1][0], item[1][1]
|
|
|
if item[0] == 'resize':
|
|
|
- pred = F.interpolate(pred, (h, w), mode='nearest')
|
|
|
+ if not isinstance(pred, np.ndarray):
|
|
|
+ pred = F.interpolate(pred, (h, w), mode='nearest')
|
|
|
+ else:
|
|
|
+ pred = cv2.resize(
|
|
|
+ pred, (h, w), interpolation=cv2.INTER_NEAREST)
|
|
|
elif item[0] == 'padding':
|
|
|
x, y = item[2]
|
|
|
- pred = pred[:, :, y:y + h, x:x + w]
|
|
|
+ if not isinstance(pred, np.ndarray):
|
|
|
+ pred = pred[:, :, y:y + h, x:x + w]
|
|
|
+ else:
|
|
|
+ pred = pred[..., y:y + h, x:x + w]
|
|
|
else:
|
|
|
pass
|
|
|
results.append(pred)
|
|
|
- batch_pred = paddle.concat(results, axis=0)
|
|
|
+ if not isinstance(pred, np.ndarray):
|
|
|
+ batch_pred = paddle.concat(results, axis=0)
|
|
|
+ else:
|
|
|
+ batch_pred = np.stack(results, axis=0)
|
|
|
return batch_pred
|
|
|
|
|
|
|