sunyanfang01 5 years ago
parent
commit
1b0d4d53da

+ 1 - 1
paddlex/__init__.py

@@ -28,7 +28,7 @@ from . import seg
 from . import cls
 from . import slim
 from . import tools
-from . import explanation
+from . import interpret


 try:
     import pycocotools

+ 1 - 1
paddlex/cv/models/classifier.py

@@ -275,7 +275,7 @@ class BaseClassifier(BaseAPI):
         } for l in pred_label]
         return res
     
-    def explanation_predict(self, images):
+    def interpretation_predict(self, images):
         self.arrange_transforms(
                 transforms=self.test_transforms, mode='test')
         new_imgs = []

BIN
paddlex/cv/models/interpret/__pycache__/visualize.cpython-37.pyc


BIN
paddlex/cv/models/interpret/as_data_reader/__pycache__/data_path_utils.cpython-37.pyc


BIN
paddlex/cv/models/interpret/as_data_reader/__pycache__/readers.cpython-37.pyc


+ 0 - 0
paddlex/cv/models/explanation/as_data_reader/data_path_utils.py → paddlex/cv/models/interpret/as_data_reader/data_path_utils.py


+ 0 - 0
paddlex/cv/models/explanation/as_data_reader/readers.py → paddlex/cv/models/interpret/as_data_reader/readers.py


BIN
paddlex/cv/models/interpret/core/__pycache__/_session_preparation.cpython-37.pyc


BIN
paddlex/cv/models/interpret/core/__pycache__/interpretation.cpython-37.pyc


BIN
paddlex/cv/models/interpret/core/__pycache__/interpretation_algorithms.cpython-37.pyc


BIN
paddlex/cv/models/interpret/core/__pycache__/lime_base.cpython-37.pyc


BIN
paddlex/cv/models/interpret/core/__pycache__/normlime_base.cpython-37.pyc


+ 0 - 0
paddlex/cv/models/explanation/core/_session_preparation.py → paddlex/cv/models/interpret/core/_session_preparation.py


+ 9 - 9
paddlex/cv/models/explanation/core/explanation.py → paddlex/cv/models/interpret/core/interpretation.py

@@ -12,31 +12,31 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.

-from .explanation_algorithms import CAM, LIME, NormLIME
+from .interpretation_algorithms import CAM, LIME, NormLIME
 from .normlime_base import precompute_normlime_weights


-class Explanation(object):
+class Interpretation(object):
     """
-    Base class for all explanation algorithms.
+    Base class for all interpretation algorithms.
     """
-    def __init__(self, explanation_algorithm_name, predict_fn, label_names, **kwargs):
+    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
         supported_algorithms = {
             'cam': CAM,
             'lime': LIME,
             'normlime': NormLIME
         }

-        self.algorithm_name = explanation_algorithm_name.lower()
+        self.algorithm_name = interpretation_algorithm_name.lower()
         assert self.algorithm_name in supported_algorithms.keys()
         self.predict_fn = predict_fn

-        # initialization for the explanation algorithm.
-        self.explain_algorithm = supported_algorithms[self.algorithm_name](
+        # initialization for the interpretation algorithm.
+        self.algorithm = supported_algorithms[self.algorithm_name](
             self.predict_fn, label_names, **kwargs
         )

-    def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
         """

         Args:
@@ -48,4 +48,4 @@ class Explanation(object):
         Returns:

         """
-        return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
+        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)

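Note on the renamed dispatcher: Interpretation now picks one of the three algorithms ('cam', 'lime', 'normlime') by name and forwards predict_fn and label_names to it. A minimal usage sketch, under stated assumptions (the predict_fn wrapper and label list below are hypothetical placeholders, not part of this commit):

    # hypothetical wrapper around the renamed classifier entry point
    def predict_fn(images):
        return model.interpretation_predict(images)[0]

    label_names = ['class_0', 'class_1']  # hypothetical label names
    lime = Interpretation('lime', predict_fn, label_names,
                          num_samples=3000, batch_size=50)
    lime.interpret(img, visualization=True, save_to_disk=True, save_dir='./tmp')
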
+ 30 - 28
paddlex/cv/models/explanation/core/explanation_algorithms.py → paddlex/cv/models/interpret/core/interpretation_algorithms.py

@@ -46,12 +46,13 @@ class CAM(object):
         logit = result[0][0]
         if abs(np.sum(logit) - 1.0) > 1e-4:
             # softmax
+            logit = logit - np.max(logit)
             exp_result = np.exp(logit)
             probability = exp_result / np.sum(exp_result)
         else:
             probability = logit

-        # only explain top 1
+        # only interpret top 1
         pred_label = np.argsort(probability)
         pred_label = pred_label[-1:]
 
 
@@ -71,7 +72,7 @@ class CAM(object):
         print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
         return feature_maps, fc_weights

-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
         feature_maps, fc_weights = self.preparation_cam(data_)
         cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
 
 
@@ -123,7 +124,7 @@ class LIME(object):
         self.predict_fn = predict_fn
         self.labels = None
         self.image = None
-        self.lime_explainer = None
+        self.lime_interpreter = None
         self.label_names = label_names

     def preparation_lime(self, data_):
@@ -134,12 +135,13 @@ class LIME(object):
 
 
         if abs(np.sum(result) - 1.0) > 1e-4:
             # softmax
+            result = result - np.max(result)
             exp_result = np.exp(result)
             probability = exp_result / np.sum(exp_result)
         else:
             probability = result

-        # only explain top 1
+        # only interpret top 1
         pred_label = np.argsort(probability)
         pred_label = pred_label[-1:]
 
 
@@ -156,14 +158,14 @@ class LIME(object):
         print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')

         end = time.time()
-        algo = lime_base.LimeImageExplainer()
-        explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
-                                          num_samples=self.num_samples, batch_size=self.batch_size)
-        self.lime_explainer = explainer
+        algo = lime_base.LimeImageInterpreter()
+        interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
+                                              num_samples=self.num_samples, batch_size=self.batch_size)
+        self.lime_interpreter = interpreter
         print('lime time: ', time.time() - end, 's.')

-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
-        if self.lime_explainer is None:
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.lime_interpreter is None:
             self.preparation_lime(data_)

         if visualization or save_to_disk:
@@ -187,13 +189,13 @@ class LIME(object):
             axes[0].imshow(self.image)
             axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")

-            axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
+            axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")

             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
-                temp, mask = self.lime_explainer.get_image_and_mask(
+                num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
+                temp, mask = self.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
@@ -274,20 +276,20 @@ class NormLIME(object):
         print('performing NormLIME operations ...')

         cluster_labels = self.predict_cluster_labels(
-            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_explainer.segments
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
         )

         g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)

         return g_weights

-    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
         if self.normlime_weights is None:
             raise ValueError("Not find the correct precomputed NormLIME result. \n"
                              "\t Try to call compute_normlime_weights() first or load the correct path.")

         g_weights = self.preparation_normlime(data_)
-        lime_weights = self._lime.lime_explainer.local_exp
+        lime_weights = self._lime.lime_interpreter.local_weights

         if visualization or save_to_disk:
             import matplotlib.pyplot as plt
@@ -312,23 +314,23 @@ class NormLIME(object):
             axes[0].imshow(self.image)
             axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")

-            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
+            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")

             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
                 nums_to_show.append(num_to_show)
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
                 axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")

             # NormLIME visualization
-            self._lime.lime_explainer.local_exp = g_weights
+            self._lime.lime_interpreter.local_weights = g_weights
             for i, num_to_show in enumerate(nums_to_show):
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
@@ -336,15 +338,15 @@ class NormLIME(object):
 
 
             # NormLIME*LIME visualization
             combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
-            self._lime.lime_explainer.local_exp = combined_weights
+            self._lime.lime_interpreter.local_weights = combined_weights
             for i, num_to_show in enumerate(nums_to_show):
-                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
                 axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")

-            self._lime.lime_explainer.local_exp = lime_weights
+            self._lime.lime_interpreter.local_weights = lime_weights

         if save_to_disk and save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
@@ -354,9 +356,9 @@ class NormLIME(object):
             plt.show()


-def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
-    segments = lime_explainer.segments
-    lime_weights = lime_explainer.local_exp[label]
+def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
+    segments = lime_interpreter.segments
+    lime_weights = lime_interpreter.local_weights[label]
     num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8

     # l1 norm with filtered weights.
@@ -381,7 +383,7 @@ def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
         return 5

     if n == 0:
-        return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
+        return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)

     return n
 
 

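Side note on the two added max-subtraction lines above (logit = logit - np.max(logit) and result = result - np.max(result)): they apply the standard numerically stable softmax trick, which leaves the probabilities unchanged but keeps np.exp from overflowing on large logits. A standalone sketch of the same idea:

    import numpy as np

    def stable_softmax(logits):
        # shifting by the max does not change the softmax output,
        # but prevents np.exp from overflowing for large logit values
        shifted = logits - np.max(logits)
        exp_shifted = np.exp(shifted)
        return exp_shifted / np.sum(exp_shifted)
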
+ 89 - 75
paddlex/cv/models/explanation/core/lime_base.py → paddlex/cv/models/interpret/core/lime_base.py

@@ -1,18 +1,33 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-from __future__ import print_function
+"""
+Copyright (c) 2016, Marco Tulio Correia Ribeiro
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+"""
+The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
+"""
+
+
 import numpy as np
 import scipy as sp
 import sklearn
@@ -88,7 +103,7 @@ class LimeBase(object):
         return np.array(used_features)

     def feature_selection(self, data, labels, weights, num_features, method):
-        """Selects features for the model. see explain_instance_with_data to
+        """Selects features for the model. see interpret_instance_with_data to
            understand the parameters."""
         if method == 'none':
             return np.array(range(data.shape[1]))
@@ -154,15 +169,15 @@ class LimeBase(object):
             return self.feature_selection(data, labels, weights,
                                           num_features, n_method)

-    def explain_instance_with_data(self,
-                                   neighborhood_data,
-                                   neighborhood_labels,
-                                   distances,
-                                   label,
-                                   num_features,
-                                   feature_selection='auto',
-                                   model_regressor=None):
-        """Takes perturbed data, labels and distances, returns explanation.
+    def interpret_instance_with_data(self,
+                                     neighborhood_data,
+                                     neighborhood_labels,
+                                     distances,
+                                     label,
+                                     num_features,
+                                     feature_selection='auto',
+                                     model_regressor=None):
+        """Takes perturbed data, labels and distances, returns interpretation.

         Args:
             neighborhood_data: perturbed data, 2d array. first element is
@@ -170,8 +185,8 @@ class LimeBase(object):
             neighborhood_labels: corresponding perturbed labels. should have as
                                  many columns as the number of possible labels.
             distances: distances to original data point.
-            label: label for which we want an explanation
-            num_features: maximum number of features in explanation
+            label: label for which we want an interpretation
+            num_features: maximum number of features in interpretation
             feature_selection: how to select num_features. options are:
                 'forward_selection': iteratively add features to the model.
                     This is costly when num_features is high
@@ -183,7 +198,7 @@ class LimeBase(object):
                 'none': uses all features, ignores num_features
                 'auto': uses forward_selection if num_features <= 6, and
                     'highest_weights' otherwise.
-            model_regressor: sklearn regressor to use in explanation.
+            model_regressor: sklearn regressor to use in interpretation.
                 Defaults to Ridge regression if None. Must have
                 model_regressor.coef_ and 'sample_weight' as a parameter
                 to model_regressor.fit()
@@ -194,8 +209,8 @@ class LimeBase(object):
             exp is a sorted list of tuples, where each tuple (x,y) corresponds
             to the feature id (x) and the local weight (y). The list is sorted
             by decreasing absolute value of y.
-            score is the R^2 value of the returned explanation
-            local_pred is the prediction of the explanation model on the original instance
+            score is the R^2 value of the returned interpretation
+            local_pred is the prediction of the interpretation model on the original instance
         """

         weights = self.kernel_fn(distances)
@@ -227,7 +242,7 @@ class LimeBase(object):
                 prediction_score, local_pred)


-class ImageExplanation(object):
+class ImageInterpretation(object):
     def __init__(self, image, segments):
         """Init function.
 
 
@@ -238,7 +253,7 @@ class ImageExplanation(object):
         self.image = image
         self.segments = segments
         self.intercept = {}
-        self.local_exp = {}
+        self.local_weights = {}
         self.local_pred = None

     def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
@@ -246,40 +261,40 @@ class ImageExplanation(object):
         """Init function.
         """Init function.
 
 
         Args:
         Args:
-            label: label to explain
+            label: label to interpret
             positive_only: if True, only take superpixels that positively contribute to
             positive_only: if True, only take superpixels that positively contribute to
                 the prediction of the label.
                 the prediction of the label.
             negative_only: if True, only take superpixels that negatively contribute to
             negative_only: if True, only take superpixels that negatively contribute to
                 the prediction of the label. If false, and so is positive_only, then both
                 the prediction of the label. If false, and so is positive_only, then both
                 negativey and positively contributions will be taken.
                 negativey and positively contributions will be taken.
                 Both can't be True at the same time
                 Both can't be True at the same time
-            hide_rest: if True, make the non-explanation part of the return
+            hide_rest: if True, make the non-interpretation part of the return
                 image gray
                 image gray
-            num_features: number of superpixels to include in explanation
-            min_weight: minimum weight of the superpixels to include in explanation
+            num_features: number of superpixels to include in interpretation
+            min_weight: minimum weight of the superpixels to include in interpretation
 
 
         Returns:
         Returns:
             (image, mask), where image is a 3d numpy array and mask is a 2d
             (image, mask), where image is a 3d numpy array and mask is a 2d
             numpy array that can be used with
             numpy array that can be used with
             skimage.segmentation.mark_boundaries
             skimage.segmentation.mark_boundaries
         """
         """
-        if label not in self.local_exp:
-            raise KeyError('Label not in explanation')
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
         if positive_only & negative_only:
         if positive_only & negative_only:
             raise ValueError("Positive_only and negative_only cannot be true at the same time.")
             raise ValueError("Positive_only and negative_only cannot be true at the same time.")
         segments = self.segments
         segments = self.segments
         image = self.image
         image = self.image
-        exp = self.local_exp[label]
+        local_weights_label = self.local_weights[label]
         mask = np.zeros(segments.shape, segments.dtype)
         mask = np.zeros(segments.shape, segments.dtype)
         if hide_rest:
         if hide_rest:
             temp = np.zeros(self.image.shape)
             temp = np.zeros(self.image.shape)
         else:
         else:
             temp = self.image.copy()
             temp = self.image.copy()
         if positive_only:
         if positive_only:
-            fs = [x[0] for x in exp
+            fs = [x[0] for x in local_weights_label
                   if x[1] > 0 and x[1] > min_weight][:num_features]
                   if x[1] > 0 and x[1] > min_weight][:num_features]
         if negative_only:
         if negative_only:
-            fs = [x[0] for x in exp
+            fs = [x[0] for x in local_weights_label
                   if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
                   if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
         if positive_only or negative_only:
         if positive_only or negative_only:
             for f in fs:
             for f in fs:
@@ -287,7 +302,7 @@ class ImageExplanation(object):
                 mask[segments == f] = 1
             return temp, mask
         else:
-            for f, w in exp[:num_features]:
+            for f, w in local_weights_label[:num_features]:
                 if np.abs(w) < min_weight:
                     continue
                 c = 0 if w < 0 else 1
@@ -300,32 +315,31 @@ class ImageExplanation(object):
         """
         """
 
 
         Args:
         Args:
-            label: label to explain
+            label: label to interpret
             min_weight:
             min_weight:
 
 
         Returns:
         Returns:
             image, is a 3d numpy array
             image, is a 3d numpy array
         """
         """
-        if label not in self.local_exp:
-            raise KeyError('Label not in explanation')
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
 
 
         from matplotlib import cm
         from matplotlib import cm
 
 
         segments = self.segments
         segments = self.segments
         image = self.image
         image = self.image
-        exp = self.local_exp[label]
+        local_weights_label = self.local_weights[label]
         temp = np.zeros_like(image)
         temp = np.zeros_like(image)
 
 
-        weight_max = abs(exp[0][1])
-        exp = [(f, w/weight_max) for f, w in exp]
-        exp = sorted(exp, key=lambda x: x[1], reverse=True)  # negatives are at last.
+        weight_max = abs(local_weights_label[0][1])
+        local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
+        local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True)  # negatives are at last.
 
 
         cmaps = cm.get_cmap('Spectral')
         cmaps = cm.get_cmap('Spectral')
-        # sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
-        colors = cmaps(np.linspace(0, 1, len(exp)))
+        colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
         colors = colors[:, :3]
         colors = colors[:, :3]
 
 
-        for i, (f, w) in enumerate(exp):
+        for i, (f, w) in enumerate(local_weights_label):
             if np.abs(w) < min_weight:
             if np.abs(w) < min_weight:
                 continue
                 continue
             temp[segments == f] = image[segments == f].copy()
             temp[segments == f] = image[segments == f].copy()
@@ -333,14 +347,14 @@ class ImageExplanation(object):
         return temp
 
 
 
 
-class LimeImageExplainer(object):
-    """Explains predictions on Image (i.e. matrix) data.
+class LimeImageInterpreter(object):
+    """Interpres predictions on Image (i.e. matrix) data.
     For numerical features, perturb them by sampling from a Normal(0,1) and
     doing the inverse operation of mean-centering and scaling, according to the
     means and stds in the training data. For categorical features, perturb by
     sampling according to the training distribution, and making a binary
     feature that is 1 when the value is the same as the instance being
-    explained."""
+    interpreted."""

     def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                  feature_selection='auto', random_state=None):
@@ -355,7 +369,7 @@ class LimeImageExplainer(object):
             verbose: if true, print local prediction values from linear model
             feature_selection: feature selection method. can be
                 'forward_selection', 'lasso_path', 'none' or 'auto'.
-                See function 'explain_instance_with_data' in lime_base.py for
+                See function 'interpret_instance_with_data' in lime_base.py for
                 details on what each of the options does.
             random_state: an integer or numpy.RandomState that will be used to
                 generate random numbers. If None, the random state will be
@@ -373,18 +387,18 @@ class LimeImageExplainer(object):
         self.feature_selection = feature_selection
         self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)

-    def explain_instance(self, image, classifier_fn, labels=(1,),
-                         hide_color=None,
-                         num_features=100000, num_samples=1000,
-                         batch_size=10,
-                         distance_metric='cosine',
-                         model_regressor=None
-                         ):
-        """Generates explanations for a prediction.
+    def interpret_instance(self, image, classifier_fn, labels=(1,),
+                           hide_color=None,
+                           num_features=100000, num_samples=1000,
+                           batch_size=10,
+                           distance_metric='cosine',
+                           model_regressor=None
+                           ):
+        """Generates interpretations for a prediction.

         First, we generate neighborhood data by randomly perturbing features
         from the instance (see __data_inverse). We then learn locally weighted
-        linear models on this neighborhood data to explain each of the classes
+        linear models on this neighborhood data to interpret each of the classes
         in an interpretable way (see lime_base.py).

         Args:
@@ -393,19 +407,19 @@ class LimeImageExplainer(object):
             classifier_fn: classifier prediction probability function, which
                 takes a numpy array and outputs prediction probabilities.  For
                 ScikitClassifiers , this is classifier.predict_proba.
-            labels: iterable with labels to be explained.
+            labels: iterable with labels to be interpreted.
             hide_color: TODO
-            num_features: maximum number of features present in explanation
+            num_features: maximum number of features present in interpretation
             num_samples: size of the neighborhood to learn the linear model
             batch_size: TODO
             distance_metric: the distance metric to use for weights.
-            model_regressor: sklearn regressor to use in explanation. Defaults
+            model_regressor: sklearn regressor to use in interpretation. Defaults
             to Ridge regression in LimeBase. Must have model_regressor.coef_
             and 'sample_weight' as a parameter to model_regressor.fit()

         Returns:
-            An ImageExplanation object (see lime_image.py) with the corresponding
-            explanations.
+            An ImageInterpretation object (see lime_image.py) with the corresponding
+            interpretations.
         """
         """
         if len(image.shape) == 2:
         if len(image.shape) == 2:
             image = gray2rgb(image)
             image = gray2rgb(image)
@@ -455,15 +469,15 @@ class LimeImageExplainer(object):
             metric=distance_metric
         ).ravel()

-        ret_exp = ImageExplanation(image, segments)
+        interpretation_image = ImageInterpretation(image, segments)
         for label in top:
-            (ret_exp.intercept[label],
-             ret_exp.local_exp[label],
-             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
+            (interpretation_image.intercept[label],
+             interpretation_image.local_weights[label],
+             interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
                 data, labels, distances, label, num_features,
                 model_regressor=model_regressor,
                 feature_selection=self.feature_selection)
-        return ret_exp
+        return interpretation_image

     def data_labels(self,
                     image,

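For reference, the renamed LIME entry point keeps the upstream lime call pattern. A minimal sketch, assuming the core package is importable as shown and using a dummy classifier_fn (both are illustrative placeholders, not part of this commit):

    import numpy as np
    from paddlex.cv.models.interpret.core import lime_base  # assumed import path

    def classifier_fn(images):
        # placeholder: return an (N, num_classes) array of probabilities
        return np.full((images.shape[0], 2), 0.5)

    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)  # dummy image
    algo = lime_base.LimeImageInterpreter()
    interpretation = algo.interpret_instance(image, classifier_fn, labels=(0,),
                                             num_samples=1000, batch_size=10)
    temp, mask = interpretation.get_image_and_mask(0, positive_only=False,
                                                   hide_rest=False, num_features=5)
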
+ 6 - 6
paddlex/cv/models/explanation/core/normlime_base.py → paddlex/cv/models/interpret/core/normlime_base.py

@@ -87,11 +87,11 @@ def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_
     return compute_normlime_weights(fname_list, save_dir, num_samples)


-def save_one_lime_predict_and_kmean_labels(lime_exp_all_weights, image_pred_labels, cluster_labels, save_path):
+def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):

     lime_weights = {}
     for label in image_pred_labels:
-        lime_weights[label] = lime_exp_all_weights[label]
+        lime_weights[label] = lime_all_weights[label]

     for_normlime_weights = {
         'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
@@ -145,15 +145,15 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
 
 
         pred_label = pred_label[:top_k]

-        algo = lime_base.LimeImageExplainer()
-        explainer = algo.explain_instance(image_show[0], predict_fn, pred_label, 0,
+        algo = lime_base.LimeImageInterpreter()
+        interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
                                           num_samples=num_samples, batch_size=batch_size)

         cluster_labels = kmeans_model.predict(
-            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), explainer.segments)
+            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
         )
         save_one_lime_predict_and_kmean_labels(
-            explainer.local_exp, pred_label,
+            interpreter.local_weights, pred_label,
             cluster_labels,
             save_path
         )

+ 20 - 20
paddlex/cv/models/explanation/visualize.py → paddlex/cv/models/interpret/visualize.py

@@ -17,19 +17,19 @@ import cv2
 import copy
 import os.path as osp
 import numpy as np
-from .core.explanation import Explanation
+from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights


 def visualize(img_file, 
               model, 
               dataset=None,
-              explanation_type='lime',
+              algo='lime',
               num_samples=3000, 
               batch_size=50,
               save_dir='./'):
     if model.status != 'Normal':
-        raise Exception('The explanation only can deal with the Normal model')
+        raise Exception('The interpretation only can deal with the Normal model')
     model.arrange_transforms(
                 transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
@@ -37,48 +37,48 @@ def visualize(img_file,
     img = tmp_transforms(img_file)[0]
     img = np.around(img).astype('uint8')
     img = np.expand_dims(img, axis=0)
-    explaier = None
-    if explanation_type == 'lime':
-        explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
-    elif explanation_type == 'normlime':
+    interpreter = None
+    if algo == 'lime':
+        interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
+    elif algo == 'normlime':
         if dataset is None:
-            raise Exception('The dataset is None. Cannot implement this kind of explanation')
-        explaier = get_normlime_explaier(img, model, dataset, 
+            raise Exception('The dataset is None. Cannot implement this kind of interpretation')
+        interpreter = get_normlime_interpreter(img, model, dataset, 
                                      num_samples=num_samples, batch_size=batch_size,
                                      save_dir=save_dir)
     else:
-        raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
+        raise Exception('The {} interpretation method is not supported yet!'.format(algo))
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    explaier.explain(img, save_dir=save_dir)
+    interpreter.interpret(img, save_dir=save_dir)
    
    
-def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
+def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
     def predict_func(image):
         image = image.astype('float32')
         for i in range(image.shape[0]):
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
     if dataset is not None:
         labels_name = dataset.labels
-    explaier = Explanation('lime', 
+    interpreter = Interpretation('lime', 
                             predict_func,
                             labels_name,
                             num_samples=num_samples, 
                             batch_size=batch_size)
-    return explaier
+    return interpreter


-def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
+def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
     def precompute_predict_func(image):
         image = image.astype('float32')
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     def predict_func(image):
@@ -87,7 +87,7 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = model.explanation_predict(image)
+        out = model.interpretation_predict(image)
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
@@ -105,13 +105,13 @@ def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50,
                                       num_samples=num_samples, 
                                       batch_size=batch_size,
                                       save_dir=save_dir)
-    explaier = Explanation('normlime', 
+    interpreter = Interpretation('normlime', 
                             predict_func,
                             labels_name,
                             num_samples=num_samples, 
                             batch_size=batch_size,
                             normlime_weights=npy_dir)
-    return explaier
+    return interpreter


 def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):

+ 1 - 1
paddlex/explanation.py → paddlex/interpret.py

@@ -13,6 +13,6 @@
 # limitations under the License.

 from __future__ import absolute_import
-from .cv.models.explanation import visualize
+from .cv.models.interpret import visualize

 visualize = visualize.visualize
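
Taken together, the user-facing entry point moves from paddlex.explanation.visualize to paddlex.interpret.visualize. A hedged usage sketch based on the visualize signature shown above (the model path and image file are placeholders, and pdx.load_model is assumed to be the usual PaddleX model loader):

    import paddlex as pdx

    model = pdx.load_model('output/mobilenetv2/best_model')  # placeholder path
    # dataset can stay None for algo='lime'; 'normlime' requires a dataset
    pdx.interpret.visualize('test.jpg', model, dataset=None,
                            algo='lime', num_samples=3000,
                            batch_size=50, save_dir='./output')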