浏览代码

update predict function

chenguowei01 5 年之前
父节点
当前提交
cd965f2504
共有 1 个文件被更改,包括 6 次插入1 次删除
  1. 6 1
      paddlex/cv/models/deeplabv3p.py

+ 6 - 1
paddlex/cv/models/deeplabv3p.py

@@ -400,14 +400,19 @@ class DeepLabv3p(BaseAPI):
             fetch_list=list(self.test_outputs.values()))
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
+        logit = result[1]
+        logit = np.squeeze(logit)
+        logit = np.transpose(logit, (1, 2, 0))
         for info in im_info[::-1]:
             if info[0] == 'resize':
                 w, h = info[1][1], info[1][0]
                 pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
+                logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
             elif info[0] == 'padding':
                 w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
+                logit = logit[0:h, 0:w, :]
             else:
                 raise Exception("Unexpected info '{}' in im_info".format(
                     info[0]))
-        return {'label_map': pred, 'score_map': result[1]}
+        return {'label_map': pred, 'score_map': logit}