Эх сурвалжийг харах

optimize seg python deploy postprocess

will-jl944 4 жил өмнө
parent
commit
a2804bf07c
1 өөрчлөгдсөн 6 нэмэгдсэн , 13 устгасан
  1. 6 13
      paddlex/deploy.py

+ 6 - 13
paddlex/deploy.py

@@ -147,34 +147,27 @@ class Predictor(object):
         if self._model.model_type == 'classifier':
         if self._model.model_type == 'classifier':
             true_topk = min(self._model.num_classes, topk)
             true_topk = min(self._model.num_classes, topk)
             preds = self._model._postprocess(net_outputs[0], true_topk)
             preds = self._model._postprocess(net_outputs[0], true_topk)
-            if len(preds) == 1:
-                preds = preds[0]
         elif self._model.model_type == 'segmenter':
         elif self._model.model_type == 'segmenter':
             label_map, score_map = self._model._postprocess(
             label_map, score_map = self._model._postprocess(
                 net_outputs,
                 net_outputs,
                 batch_origin_shape=ori_shape,
                 batch_origin_shape=ori_shape,
                 transforms=transforms.transforms)
                 transforms=transforms.transforms)
-            label_map = np.squeeze(label_map)
-            score_map = np.squeeze(score_map)
-            if score_map.ndim == 3:
-                preds = {'label_map': label_map, 'score_map': score_map}
-            else:
-                preds = [{
-                    'label_map': l,
-                    'score_map': s
-                } for l, s in zip(label_map, score_map)]
+            preds = [{
+                'label_map': l,
+                'score_map': s
+            } for l, s in zip(label_map, score_map)]
         elif self._model.model_type == 'detector':
         elif self._model.model_type == 'detector':
             net_outputs = {
             net_outputs = {
                 k: v
                 k: v
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
                 for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
             }
             }
             preds = self._model._postprocess(net_outputs)
             preds = self._model._postprocess(net_outputs)
-            if len(preds) == 1:
-                preds = preds[0]
         else:
         else:
             logging.error(
             logging.error(
                 "Invalid model type {}.".format(self._model.model_type),
                 "Invalid model type {}.".format(self._model.model_type),
                 exit=True)
                 exit=True)
+        if len(preds) == 1:
+            preds = preds[0]
 
 
         return preds
         return preds