|
|
@@ -478,12 +478,15 @@ class BaseSegmenter(BaseModel):
|
|
|
label_map = label_map.numpy().astype('uint8')
|
|
|
score_map = outputs['score_map']
|
|
|
score_map = score_map.numpy().astype('float32')
|
|
|
- prediction = [{
|
|
|
- 'label_map': l,
|
|
|
- 'score_map': s
|
|
|
- } for l, s in zip(label_map, score_map)]
|
|
|
- if isinstance(img_file, (str, np.ndarray)):
|
|
|
- prediction = prediction[0]
|
|
|
+ if isinstance(img_file, list) and len(img_file) > 1:
|
|
|
+ prediction = [{
|
|
|
+ 'label_map': l,
|
|
|
+ 'score_map': s
|
|
|
+ } for l, s in zip(label_map, score_map)]
|
|
|
+ elif isinstance(img_file, list):
|
|
|
+ prediction = [{'label_map': label_map, 'score_map': score_map}]
|
|
|
+ else:
|
|
|
+ prediction = {'label_map': label_map, 'score_map': score_map}
|
|
|
return prediction
|
|
|
|
|
|
def _preprocess(self, images, transforms, model_type):
|