浏览代码

fix python

syyxsxx 5 年之前
父节点
当前提交
b1395ec55d
共有 1 个文件被更改,包括 7 次插入5 次删除
  1. 7 5
      deploy/openvino/python/deploy.py

+ 7 - 5
deploy/openvino/python/deploy.py

@@ -180,11 +180,11 @@ class Predictor:
         """
         """
         it = iter(self.net.outputs)
         it = iter(self.net.outputs)
         next(it)
         next(it)
+        label_name = next(it)
+        label_map = np.squeeze(preds[label_name]).astype('uint8')
         score_name = next(it)
         score_name = next(it)
         score_map = np.squeeze(preds[score_name])
         score_map = np.squeeze(preds[score_name])
         score_map = np.transpose(score_map, (1, 2, 0))
         score_map = np.transpose(score_map, (1, 2, 0))
-        label_name = next(it)
-        label_map = np.squeeze(preds[label_name]).astype('uint8')
         im_info = preprocessed_inputs['im_info']
         im_info = preprocessed_inputs['im_info']
         for info in im_info[::-1]:
         for info in im_info[::-1]:
             if info[0] == 'resize':
             if info[0] == 'resize':
@@ -203,10 +203,12 @@ class Predictor:
     def detector_postprocess(self, preds, preprocessed_inputs):
     def detector_postprocess(self, preds, preprocessed_inputs):
         """对图像检测结果做后处理
         """对图像检测结果做后处理
         """
         """
-        output_name = next(iter(self.net.outputs))
-        outputs = preds[output_name][0]
+        outputs = self.net.outputs
+        for name in outpus:
+            if (len(outputs[name].shape == 3)):
+                output = preds[name][0]
         result = []
         result = []
-        for out in outputs:
+        for out in output:
             if (out[0] > 0):
             if (out[0] > 0):
                 result.append(out.tolist())
                 result.append(out.tolist())
             else:
             else: