Ver Fonte

fix scope bug for interpret

jiangjiajun há 5 anos atrás
pai
commit
b1ff0a34d0

+ 12 - 4
paddlex/interpret/interpretation_predict.py

@@ -15,11 +15,17 @@
 import numpy as np
 import cv2
 import copy
+import paddle.fluid as fluid
+from paddlex.cv.transforms import arrange_transforms
 
 
 def interpretation_predict(model, images):
     images = images.astype('float32')
-    model.arrange_transforms(transforms=model.test_transforms, mode='test')
+    arrange_transforms(
+        model.model_type,
+        model.__class__.__name__,
+        transforms=model.test_transforms,
+        mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
     model.test_transforms.transforms = model.test_transforms.transforms[-2:]
 
@@ -29,9 +35,11 @@ def interpretation_predict(model, images):
         new_imgs.append(model.test_transforms(images[i])[0])
 
     new_imgs = np.array(new_imgs)
-    out = model.exe.run(model.test_prog,
-                        feed={'image': new_imgs},
-                        fetch_list=list(model.interpretation_feats.values()))
+    with fluid.scope_guard(model.scope):
+        out = model.exe.run(
+            model.test_prog,
+            feed={'image': new_imgs},
+            fetch_list=list(model.interpretation_feats.values()))
 
     model.test_transforms.transforms = tmp_transforms
 

+ 14 - 5
paddlex/interpret/visualize.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -22,6 +22,7 @@ from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_global_classifier
 from .core._session_preparation import gen_user_home
+from paddlex.cv.transforms import arrange_transforms
 
 
 def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
@@ -48,7 +49,11 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
             'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(transforms=model.test_transforms, mode='test')
+    arrange_transforms(
+        model.model_type,
+        model.__class__.__name__,
+        transforms=model.test_transforms,
+        mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]
@@ -94,7 +99,11 @@ def normlime(img_file,
             'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(transforms=model.test_transforms, mode='test')
+    arrange_transforms(
+        model.model_type,
+        model.__class__.__name__,
+        transforms=model.test_transforms,
+        mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]