|
|
@@ -360,18 +360,19 @@ class DeepLabv3p(BaseAPI):
|
|
|
pred = pred[0:num_samples]
|
|
|
|
|
|
for i in range(num_samples):
|
|
|
- one_pred = pred[i].astype('uint8')
|
|
|
+ one_pred = np.squeeze(pred[i]).astype('uint8')
|
|
|
one_label = labels[i]
|
|
|
for info in im_info[i][::-1]:
|
|
|
if info[0] == 'resize':
|
|
|
w, h = info[1][1], info[1][0]
|
|
|
- one_pred = cv2.resize(one_pred, (w, h), cv2.INTER_NEAREST)
|
|
|
+ one_pred = cv2.resize(one_pred, (w, h),
|
|
|
+ cv2.INTER_NEAREST)
|
|
|
elif info[0] == 'padding':
|
|
|
w, h = info[1][1], info[1][0]
|
|
|
one_pred = one_pred[0:h, 0:w]
|
|
|
else:
|
|
|
- raise Exception("Unexpected info '{}' in im_info".format(
|
|
|
- info[0]))
|
|
|
+ raise Exception(
|
|
|
+ "Unexpected info '{}' in im_info".format(info[0]))
|
|
|
one_pred = one_pred.astype('int64')
|
|
|
one_pred = one_pred[np.newaxis, :, :, np.newaxis]
|
|
|
one_label = one_label[np.newaxis, np.newaxis, :, :]
|