Pārlūkot izejas kodu

update predict function

chenguowei01 5 gadi atpakaļ
vecāks
revīzija
cd965f2504
1 mainītis faili ar 6 papildinājumiem un 1 dzēšanām
  1. 6 1
      paddlex/cv/models/deeplabv3p.py

+ 6 - 1
paddlex/cv/models/deeplabv3p.py

@@ -400,14 +400,19 @@ class DeepLabv3p(BaseAPI):
             fetch_list=list(self.test_outputs.values()))
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
+        logit = result[1]
+        logit = np.squeeze(logit)
+        logit = np.transpose(logit, (1, 2, 0))
         for info in im_info[::-1]:
             if info[0] == 'resize':
                 w, h = info[1][1], info[1][0]
                 pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
+                logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
             elif info[0] == 'padding':
                 w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
+                logit = logit[0:h, 0:w, :]
             else:
                 raise Exception("Unexpected info '{}' in im_info".format(
                     info[0]))
-        return {'label_map': pred, 'score_map': result[1]}
+        return {'label_map': pred, 'score_map': logit}