SunAhong1993 пре 5 година
родитељ
комит
fd882edb26
1 измењених фајлова са 5 додато и 3 уклоњено
  1. 5 3
      paddlex/cv/models/classifier.py

+ 5 - 3
paddlex/cv/models/classifier.py

@@ -63,7 +63,9 @@ class BaseClassifier(BaseAPI):
         net_out = model(image, num_classes=self.num_classes)
         softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
         inputs = OrderedDict([('image', image)])
-        outputs = OrderedDict([('predict', softmax_out), ('logits', net_out)])
+        outputs = OrderedDict([('predict', softmax_out)])
+        if mode == 'test':
+            self.explanation_feats = OrderedDict([('logits', net_out)])
         if mode != 'test':
             cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
             avg_cost = fluid.layers.mean(cost)
@@ -284,8 +286,8 @@ class BaseClassifier(BaseAPI):
         result = self.exe.run(
             self.test_prog,
             feed={'image': new_imgs},
-            fetch_list=list(self.test_outputs.values()))
-        return result[1:]
+            fetch_list=list(self.explanation_feats.values()))
+        return result
 
 class ResNet18(BaseClassifier):
     def __init__(self, num_classes=1000):