Эх сурвалжийг харах

batch_prediction of segmenter returns list of dict

will-jl944 4 жил өмнө
parent
commit
2f573a3678

+ 7 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -463,7 +463,13 @@ class BaseSegmenter(BaseModel):
         label_map = label_map.numpy().astype('uint8')
         score_map = outputs['score_map']
         score_map = score_map.numpy().astype('float32')
-        return {'label_map': label_map, 'score_map': score_map}
+        prediction = [{
+            'label_map': l,
+            'score_map': s
+        } for l, s in zip(label_map, score_map)]
+        if isinstance(img_file, (str, np.ndarray)):
+            prediction = prediction[0]
+        return prediction
 
     def _preprocess(self, images, transforms, model_type):
         arrange_transforms(