瀏覽代碼

batch_prediction of segmenter returns list of dict

will-jl944 4 年之前
父節點
當前提交
2f573a3678
共有 1 個文件被更改,包括 7 次插入1 次删除
  1. 7 1
      dygraph/paddlex/cv/models/segmenter.py

+ 7 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -463,7 +463,13 @@ class BaseSegmenter(BaseModel):
         label_map = label_map.numpy().astype('uint8')
         score_map = outputs['score_map']
         score_map = score_map.numpy().astype('float32')
-        return {'label_map': label_map, 'score_map': score_map}
+        prediction = [{
+            'label_map': l,
+            'score_map': s
+        } for l, s in zip(label_map, score_map)]
+        if isinstance(img_file, (str, np.ndarray)):
+            prediction = prediction[0]
+        return prediction
 
     def _preprocess(self, images, transforms, model_type):
         arrange_transforms(