瀏覽代碼

fix seg prediction bug when input is single image

will-jl944 4 年之前
父節點
當前提交
1180aeef4e
共有 1 個文件被更改,包括 9 次插入6 次删除
  1. 9 6
      dygraph/paddlex/cv/models/segmenter.py

+ 9 - 6
dygraph/paddlex/cv/models/segmenter.py

@@ -478,12 +478,15 @@ class BaseSegmenter(BaseModel):
         label_map = label_map.numpy().astype('uint8')
         score_map = outputs['score_map']
         score_map = score_map.numpy().astype('float32')
-        prediction = [{
-            'label_map': l,
-            'score_map': s
-        } for l, s in zip(label_map, score_map)]
-        if isinstance(img_file, (str, np.ndarray)):
-            prediction = prediction[0]
+        if isinstance(img_file, list) and len(img_file) > 1:
+            prediction = [{
+                'label_map': l,
+                'score_map': s
+            } for l, s in zip(label_map, score_map)]
+        elif isinstance(img_file, list):
+            prediction = [{'label_map': label_map, 'score_map': score_map}]
+        else:
+            prediction = {'label_map': label_map, 'score_map': score_map}
         return prediction
 
     def _preprocess(self, images, transforms, model_type):