瀏覽代碼

Update classifier.py

SunAhong1993 5 年之前
父節點
當前提交
fd882edb26
共有 1 個文件被更改,包括 5 次插入3 次删除
  1. 5 3
      paddlex/cv/models/classifier.py

+ 5 - 3
paddlex/cv/models/classifier.py

@@ -63,7 +63,9 @@ class BaseClassifier(BaseAPI):
         net_out = model(image, num_classes=self.num_classes)
         softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
         inputs = OrderedDict([('image', image)])
-        outputs = OrderedDict([('predict', softmax_out), ('logits', net_out)])
+        outputs = OrderedDict([('predict', softmax_out)])
+        if mode == 'test':
+            self.explanation_feats = OrderedDict([('logits', net_out)])
         if mode != 'test':
             cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
             avg_cost = fluid.layers.mean(cost)
@@ -284,8 +286,8 @@ class BaseClassifier(BaseAPI):
         result = self.exe.run(
             self.test_prog,
             feed={'image': new_imgs},
-            fetch_list=list(self.test_outputs.values()))
-        return result[1:]
+            fetch_list=list(self.explanation_feats.values()))
+        return result
 
 class ResNet18(BaseClassifier):
     def __init__(self, num_classes=1000):