Browse Source

modify the interpret

sunyanfang01 5 years ago
parent
commit
ab2942c123

+ 1 - 14
paddlex/cv/models/classifier.py

@@ -274,20 +274,7 @@ class BaseClassifier(BaseAPI):
             'score': result[0][0][l]
         } for l in pred_label]
         return res
-    
-    def interpretation_predict(self, images):
-        self.arrange_transforms(
-                transforms=self.test_transforms, mode='test')
-        new_imgs = []
-        for i in range(images.shape[0]):
-            img = images[i]
-            new_imgs.append(self.test_transforms(img)[0])
-        new_imgs = np.array(new_imgs)
-        result = self.exe.run(
-            self.test_prog,
-            feed={'image': new_imgs},
-            fetch_list=list(self.explanation_feats.values()))
-        return result
+
 
 class ResNet18(BaseClassifier):
     def __init__(self, num_classes=1000):

+ 1 - 1
paddlex/interpret.py → paddlex/interpret/__init__.py

@@ -13,6 +13,6 @@
 # limitations under the License.
 
 from __future__ import absolute_import
-from .cv.models.interpret import visualize
+from . import visualize
 
 visualize = visualize.visualize

BIN
paddlex/interpret/__pycache__/__init__.cpython-37.pyc


BIN
paddlex/interpret/__pycache__/interpretation_predict.cpython-37.pyc


BIN
paddlex/interpret/__pycache__/visualize.cpython-37.pyc


BIN
paddlex/interpret/as_data_reader/__pycache__/data_path_utils.cpython-37.pyc


BIN
paddlex/interpret/as_data_reader/__pycache__/readers.cpython-37.pyc


+ 0 - 0
paddlex/cv/models/interpret/as_data_reader/data_path_utils.py → paddlex/interpret/as_data_reader/data_path_utils.py


+ 0 - 0
paddlex/cv/models/interpret/as_data_reader/readers.py → paddlex/interpret/as_data_reader/readers.py


BIN
paddlex/interpret/core/__pycache__/_session_preparation.cpython-37.pyc


BIN
paddlex/interpret/core/__pycache__/interpretation.cpython-37.pyc


BIN
paddlex/interpret/core/__pycache__/interpretation_algorithms.cpython-37.pyc


BIN
paddlex/interpret/core/__pycache__/lime_base.cpython-37.pyc


BIN
paddlex/interpret/core/__pycache__/normlime_base.cpython-37.pyc


+ 0 - 0
paddlex/cv/models/interpret/core/_session_preparation.py → paddlex/interpret/core/_session_preparation.py


+ 0 - 0
paddlex/cv/models/interpret/core/interpretation.py → paddlex/interpret/core/interpretation.py


+ 0 - 0
paddlex/cv/models/interpret/core/interpretation_algorithms.py → paddlex/interpret/core/interpretation_algorithms.py


+ 0 - 0
paddlex/cv/models/interpret/core/lime_base.py → paddlex/interpret/core/lime_base.py


+ 0 - 0
paddlex/cv/models/interpret/core/normlime_base.py → paddlex/interpret/core/normlime_base.py


+ 29 - 0
paddlex/interpret/interpretation_predict.py

@@ -0,0 +1,29 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+def interpretation_predict(model, images):
+    model.arrange_transforms(
+            transforms=model.test_transforms, mode='test')
+    new_imgs = []
+    for i in range(images.shape[0]):
+        img = images[i]
+        new_imgs.append(model.test_transforms(img)[0])
+    new_imgs = np.array(new_imgs)
+    result = model.exe.run(
+        model.test_prog,
+        feed={'image': new_imgs},
+        fetch_list=list(model.explanation_feats.values()))
+    return result

+ 16 - 3
paddlex/cv/models/interpret/visualize.py → paddlex/interpret/visualize.py

@@ -17,6 +17,7 @@ import cv2
 import copy
 import os.path as osp
 import numpy as np
+from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights
 
@@ -28,6 +29,18 @@ def visualize(img_file,
               num_samples=3000, 
               batch_size=50,
               save_dir='./'):
+    """可解释性可视化。
+    Args:
+        img_file (str): 预测图像路径。
+        model (paddlex.cv.models): paddlex中的模型。
+        dataset (paddlex.datasets): 数据集读取器,默认为None。
+        algo (str): 可解释性方式,当前可选'lime'和'normlime'。
+        num_samples (int): 随机采样数量,默认为3000。
+        batch_size (int): 预测数据batch大小,默认为50。
+        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
+    """
+    assert model.model_type == 'classifier', \
+        'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
         raise Exception('The interpretation only can deal with the Normal model')
     model.arrange_transforms(
@@ -59,7 +72,7 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.interpretation_predict(image)
+        out = interpretation_predict(model, image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
@@ -78,7 +91,7 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
         image = image.astype('float32')
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.interpretation_predict(image)
+        out = interpretation_predict(model, image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     def predict_func(image):
@@ -87,7 +100,7 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.interpretation_predict(image)
+        out = interpretation_predict(model, image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None