Ver Fonte

add lime

sunyanfang01 há 5 anos atrás
pai
commit
011cff21f2

+ 16 - 10
paddlex/cv/models/classifier.py

@@ -27,7 +27,6 @@ from .base import BaseAPI
 
 class BaseClassifier(BaseAPI):
     """构建分类器,并实现其训练、评估、预测和模型导出。
-
     Args:
         model_name (str): 分类器的模型名字,取值范围为['ResNet18',
                           'ResNet34', 'ResNet50', 'ResNet101',
@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI):
         if mode != 'test':
             label = fluid.data(dtype='int64', shape=[None, 1], name='label')
         model = getattr(paddlex.cv.nets, str.lower(self.model_name))
-        net_out = model(image, num_classes=self.num_classes)
+        net_out, feat = model(image, num_classes=self.num_classes)
         softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
         inputs = OrderedDict([('image', image)])
-        outputs = OrderedDict([('predict', softmax_out)])
+        outputs = OrderedDict([('predict', softmax_out), ('net_out', feat[-1])])
         if mode != 'test':
             cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
             avg_cost = fluid.layers.mean(cost)
@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI):
               early_stop_patience=5,
               resume_checkpoint=None):
         """训练。
-
         Args:
             num_epochs (int): 训练迭代轮数。
             train_dataset (paddlex.datasets): 训练数据读取器。
@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI):
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
             resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
-
         Raises:
             ValueError: 模型从inference model进行加载。
         """
@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI):
                  epoch_id=None,
                  return_details=False):
         """评估。
-
         Args:
             eval_dataset (paddlex.datasets): 验证数据读取器。
             batch_size (int): 验证数据批大小。默认为1。
             epoch_id (int): 当前评估模型所在的训练轮数。
             return_details (bool): 是否返回详细信息。
-
         Returns:
           dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
               分别表示最大值的accuracy、前5个最大值的accuracy。
@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI):
 
     def predict(self, img_file, transforms=None, topk=1):
         """预测。
-
         Args:
             img_file (str): 预测图像路径。
             transforms (paddlex.cls.transforms): 数据预处理操作。
             topk (int): 预测时前k个最大值。
-
         Returns:
             list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
             分别对应预测类别id、预测类别标签、预测得分。
@@ -279,7 +272,20 @@ class BaseClassifier(BaseAPI):
             'score': result[0][0][l]
         } for l in pred_label]
         return res
-
+    
+    def explanation_predict(self, images):
+        self.arrange_transforms(
+                transforms=self.test_transforms, mode='test')
+        new_imgs = []
+        for i in range(images.shape[0]):
+            img = images[i]
+            new_imgs.append(self.test_transforms(img)[0])
+        new_imgs = np.array(new_imgs)
+        result = self.exe.run(
+            self.test_prog,
+            feed={'image': new_imgs},
+            fetch_list=list(self.test_outputs.values()))
+        return result[1:]
 
 class ResNet18(BaseClassifier):
     def __init__(self, num_classes=1000):

+ 27 - 0
paddlex/cv/models/explanation/as_data_reader/data_path_utils.py

@@ -0,0 +1,27 @@
+import os
+
+
+def imagenet_val_files_and_labels(dataset_directory):
+    classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines()
+    class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))}
+
+    images_path = os.path.join(dataset_directory, 'val')
+    filenames = []
+    labels = []
+    lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines()
+    for i, line in enumerate(lines):
+        class_name = line.split('\n')[0]
+        a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1)
+        filenames.append(f'{images_path}/{a}')
+        labels.append(class_to_indx[class_name])
+        # print(filenames[-1], labels[-1])
+
+    return filenames, labels
+
+
+def _find_classes(dir):
+    # Faster and available in Python 3.5 and above
+    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+    classes.sort()
+    class_to_idx = {classes[i]: i for i in range(len(classes))}
+    return classes, class_to_idx

+ 211 - 0
paddlex/cv/models/explanation/as_data_reader/readers.py

@@ -0,0 +1,211 @@
+import os
+import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import cv2
+import numpy as np
+import six
+import glob
+from as_data_reader.data_path_utils import _find_classes
+from PIL import Image
+
+
+def resize_short(img, target_size, interpolation=None):
+    """resize image
+
+    Args:
+        img: image data
+        target_size: resize short target size
+        interpolation: interpolation mode
+
+    Returns:
+        resized image data
+    """
+    percent = float(target_size) / min(img.shape[0], img.shape[1])
+    resized_width = int(round(img.shape[1] * percent))
+    resized_height = int(round(img.shape[0] * percent))
+    if interpolation:
+        resized = cv2.resize(
+            img, (resized_width, resized_height), interpolation=interpolation)
+    else:
+        resized = cv2.resize(img, (resized_width, resized_height))
+    return resized
+
+
+def crop_image(img, target_size, center=True):
+    """crop image
+
+    Args:
+        img: images data
+        target_size: crop target size
+        center: crop mode
+
+    Returns:
+        img: cropped image data
+    """
+    height, width = img.shape[:2]
+    size = target_size
+    if center:
+        w_start = (width - size) // 2
+        h_start = (height - size) // 2
+    else:
+        w_start = np.random.randint(0, width - size + 1)
+        h_start = np.random.randint(0, height - size + 1)
+    w_end = w_start + size
+    h_end = h_start + size
+    img = img[h_start:h_end, w_start:w_end, :]
+    return img
+
+
+def preprocess_image(img, random_mirror=False):
+    """
+    centered, scaled by 1/255.
+    :param img: np.array: shape: [ns, h, w, 3], color order: rgb.
+    :return: np.array: shape: [ns, h, w, 3]
+    """
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+
+    # transpose to [ns, 3, h, w]
+    img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
+
+    img_mean = np.array(mean).reshape((3, 1, 1))
+    img_std = np.array(std).reshape((3, 1, 1))
+    img -= img_mean
+    img /= img_std
+
+    if random_mirror:
+        mirror = int(np.random.uniform(0, 2))
+        if mirror == 1:
+            img = img[:, :, ::-1, :]
+
+    return img
+
+
+def read_image(img_path, target_size=256, crop_size=224):
+    """
+    resize_short to 256, then center crop to 224.
+    :param img_path: one image path
+    :return: np.array: shape: [1, h, w, 3], color order: rgb.
+    """
+
+    if isinstance(img_path, str):
+        with open(img_path, 'rb') as f:
+            img = Image.open(f)
+            img = img.convert('RGB')
+            img = np.array(img)
+            # img = cv2.imread(img_path)
+
+            img = resize_short(img, target_size, interpolation=None)
+            img = crop_image(img, target_size=crop_size, center=True)
+            # img = img[:, :, ::-1]
+            img = np.expand_dims(img, axis=0)
+            return img
+    elif isinstance(img_path, np.ndarray):
+        assert len(img_path.shape) == 4
+        return img_path
+    else:
+        ValueError(f"Not recognized data type {type(img_path)}.")
+
+
+class ReaderConfig(object):
+    """
+    A generic data loader where the images are arranged in this way:
+
+        root/train/dog/xxy.jpg
+        root/train/dog/xxz.jpg
+        ...
+        root/train/cat/nsdf3.jpg
+        root/train/cat/asd932_.jpg
+        ...
+
+        root/test/dog/xxx.jpg
+        ...
+        root/test/cat/123.jpg
+        ...
+
+    """
+    def __init__(self, dataset_dir, is_test):
+        image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test)
+        random_per = np.random.permutation(range(len(image_paths)))
+        self.image_paths = image_paths[random_per]
+        self.labels = labels[random_per]
+        self.is_test = is_test
+
+    def get_reader(self):
+        def reader():
+            IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+            target_size = 256
+            crop_size = 224
+
+            for i, img_path in enumerate(self.image_paths):
+                if not img_path.lower().endswith(IMG_EXTENSIONS):
+                    continue
+
+                img = cv2.imread(img_path)
+                if img is None:
+                    print(img_path)
+                    continue
+                img = resize_short(img, target_size, interpolation=None)
+                img = crop_image(img, crop_size, center=self.is_test)
+                img = img[:, :, ::-1]
+                img = np.expand_dims(img, axis=0)
+
+                img = preprocess_image(img, not self.is_test)
+
+                yield img, self.labels[i]
+
+        return reader
+
+    def get_dataset_info(self, dataset_dir, is_test=False):
+        IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+        # read
+        if is_test:
+            datasubset_dir = os.path.join(dataset_dir, 'test')
+        else:
+            datasubset_dir = os.path.join(dataset_dir, 'train')
+
+        class_names, class_to_idx = _find_classes(datasubset_dir)
+        # num_classes = len(class_names)
+        image_paths = []
+        labels = []
+        for class_name in class_names:
+            classes_dir = os.path.join(datasubset_dir, class_name)
+            for img_path in glob.glob(os.path.join(classes_dir, '*')):
+                if not img_path.lower().endswith(IMG_EXTENSIONS):
+                    continue
+
+                image_paths.append(img_path)
+                labels.append(class_to_idx[class_name])
+
+        image_paths = np.array(image_paths)
+        labels = np.array(labels)
+        return image_paths, labels, len(class_names)
+
+
+def create_reader(list_image_path, list_label=None, is_test=False):
+    def reader():
+        IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+        target_size = 256
+        crop_size = 224
+
+        for i, img_path in enumerate(list_image_path):
+            if not img_path.lower().endswith(IMG_EXTENSIONS):
+                continue
+
+            img = cv2.imread(img_path)
+            if img is None:
+                print(img_path)
+                continue
+
+            img = resize_short(img, target_size, interpolation=None)
+            img = crop_image(img, crop_size, center=is_test)
+            img = img[:, :, ::-1]
+            img_show = np.expand_dims(img, axis=0)
+
+            img = preprocess_image(img_show, not is_test)
+
+            label = 0 if list_label is None else list_label[i]
+
+            yield img_show, img, label
+
+    return reader

+ 13 - 0
paddlex/cv/models/explanation/core/_session_preparation.py

@@ -0,0 +1,13 @@
+import os
+import paddle.fluid as fluid
+import numpy as np
+
+
+def paddle_get_fc_weights(var_name="fc_0.w_0"):
+    fc_weights = fluid.global_scope().find_var(var_name).get_tensor()
+    return np.array(fc_weights)
+
+
+def paddle_resize(extracted_features, outsize):
+    resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
+    return resized_features

+ 37 - 0
paddlex/cv/models/explanation/core/explanation.py

@@ -0,0 +1,37 @@
+from .explanation_algorithms import CAM, LIME, NormLIME
+
+
+class Explanation(object):
+    """
+    Base class for all explanation algorithms.
+    """
+    def __init__(self, explanation_algorithm_name, predict_fn, **kwargs):
+        supported_algorithms = {
+            'cam': CAM,
+            'lime': LIME,
+            'normlime': NormLIME
+        }
+
+        self.algorithm_name = explanation_algorithm_name.lower()
+        assert self.algorithm_name in supported_algorithms.keys()
+        self.predict_fn = predict_fn
+
+        # initialization for the explanation algorithm.
+        self.explain_algorithm = supported_algorithms[self.algorithm_name](
+            self.predict_fn, **kwargs
+        )
+
+    def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+        """
+
+        Args:
+            data_: data_ can be a path or numpy.ndarray.
+            visualization: whether to show using matplotlib.
+            save_to_disk: whether to save the figure in local disk.
+            save_dir: dir to save figure if save_to_disk is True.
+
+        Returns:
+
+        """
+        return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
+

+ 458 - 0
paddlex/cv/models/explanation/core/explanation_algorithms.py

@@ -0,0 +1,458 @@
+import os
+import numpy as np
+import time
+
+from . import lime_base
+from ..as_data_reader.readers import read_image
+from ._session_preparation import paddle_get_fc_weights
+
+import cv2
+
+
+class CAM(object):
+    def __init__(self, predict_fn):
+        """
+
+        Args:
+            predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
+                        output: [
+                        logits [N, num_classes],
+                        feature map before global average pooling [N, num_channels, h_, w_]
+                        ]
+
+        """
+        self.predict_fn = predict_fn
+
+    def preparation_cam(self, data_path):
+        image_show = read_image(data_path)
+        result = self.predict_fn(image_show)
+
+        logit = result[0][0]
+        if abs(np.sum(logit) - 1.0) > 1e-4:
+            # softmax
+            exp_result = np.exp(logit)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = logit
+
+        # only explain top 1
+        pred_label = np.argsort(probability)
+        pred_label = pred_label[-1:]
+
+        self.predicted_label = pred_label[0]
+        self.predicted_probability = probability[pred_label[0]]
+        self.image = image_show[0]
+        self.labels = pred_label
+
+        fc_weights = paddle_get_fc_weights()
+        feature_maps = result[1]
+
+        print('predicted result: ', pred_label[0], probability[pred_label[0]])
+        return feature_maps, fc_weights
+
+    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        feature_maps, fc_weights = self.preparation_cam(data_)
+        cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+
+            psize = 5
+            nrows = 1
+            ncols = 2
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(cam)
+            axes[1].set_title("CAM")
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'cam')
+
+        if visualization:
+            plt.show()
+
+        return
+
+
+class LIME(object):
+    def __init__(self, predict_fn, num_samples=3000, batch_size=50):
+        """
+        LIME wrapper. See lime_base.py for the detailed LIME implementation.
+        Args:
+            predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
+            num_samples: the number of samples that LIME takes for fitting.
+            batch_size: batch size for model inference each time.
+        """
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        self.predict_fn = predict_fn
+        self.labels = None
+        self.image = None
+        self.lime_explainer = None
+
+    def preparation_lime(self, data_path):
+        image_show = read_image(data_path)
+        result = self.predict_fn(image_show)
+
+        result = result[0]  # only one image here.
+
+        if abs(np.sum(result) - 1.0) > 1e-4:
+            # softmax
+            exp_result = np.exp(result)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = result
+
+        # only explain top 1
+        pred_label = np.argsort(probability)
+        pred_label = pred_label[-1:]
+
+        self.predicted_label = pred_label[0]
+        self.predicted_probability = probability[pred_label[0]]
+        self.image = image_show[0]
+        self.labels = pred_label
+
+        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}')
+
+        end = time.time()
+        algo = lime_base.LimeImageExplainer()
+        explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
+                                          num_samples=self.num_samples, batch_size=self.batch_size)
+        self.lime_explainer = explainer
+        print('lime time: ', time.time() - end, 's.')
+
+    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.lime_explainer is None:
+            self.preparation_lime(data_)
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+
+            psize = 5
+            nrows = 2
+            weights_choices = [0.6, 0.75, 0.85]
+            ncols = len(weights_choices)
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
+            axes[1].set_title("superpixel segmentation")
+
+            # LIME visualization
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
+                temp, mask = self.lime_explainer.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'lime', self.num_samples)
+
+        if visualization:
+            plt.show()
+
+        return
+
+
+class NormLIME(object):
+    def __init__(self, predict_fn, num_samples=3000, batch_size=50,
+                 kmeans_model_for_normlime=None, normlime_weights=None):
+        assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model."
+        if normlime_weights is None:
+            raise NotImplementedError("Computing NormLIME weights is not implemented yet.")
+
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
+        self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
+
+        self.predict_fn = predict_fn
+
+        self.labels = None
+        self.image = None
+
+    def predict_cluster_labels(self, feature_map, segments):
+        return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
+
+    def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
+        # global weights
+        g_weights = {y: [] for y in pred_labels}
+        for y in pred_labels:
+            cluster_weights_y = self.normlime_weights[y]
+            g_weights[y] = [
+                # some are not in the dict, 3000 samples may be not enough.
+                (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
+            ]
+
+            g_weights[y] = sorted(g_weights[y],
+                                  key=lambda x: np.abs(x[1]), reverse=True)
+
+        return g_weights
+
+    def preparation_normlime(self, data_path):
+        self._lime = LIME(
+            lambda images: self.predict_fn(images)[0],
+            self.num_samples,
+            self.batch_size
+        )
+        self._lime.preparation_lime(data_path)
+
+        image_show = read_image(data_path)
+        result = self.predict_fn(image_show)
+
+        logit = result[0][0]  # only one image here.
+        if abs(np.sum(logit) - 1.0) > 1e-4:
+            # softmax
+            exp_result = np.exp(logit)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = logit
+
+        # only explain top 1
+        pred_label = np.argsort(probability)
+        pred_label = pred_label[-1:]
+
+        self.predicted_label = pred_label[0]
+        self.predicted_probability = probability[pred_label[0]]
+        self.image = image_show[0]
+        self.labels = pred_label
+        print('predicted result: ', pred_label[0], probability[pred_label[0]])
+
+        local_feature_map = result[1][0]
+        cluster_labels = self.predict_cluster_labels(
+            local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments
+        )
+
+        g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
+
+        return g_weights
+
+    def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        g_weights = self.preparation_normlime(data_)
+        lime_weights = self._lime.lime_explainer.local_exp
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+
+            psize = 5
+            nrows = 4
+            weights_choices = [0.6, 0.85, 0.99]
+            ncols = len(weights_choices)
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
+            axes[1].set_title("superpixel segmentation")
+
+            # LIME visualization
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
+
+            # NormLIME visualization
+            self._lime.lime_explainer.local_exp = g_weights
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels")
+
+            # NormLIME*LIME visualization
+            combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
+            self._lime.lime_explainer.local_exp = combined_weights
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                temp, mask = self._lime.lime_explainer.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels")
+
+            self._lime.lime_explainer.local_exp = lime_weights
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'normlime', self.num_samples)
+
+        if visualization:
+            plt.show()
+
+
+def load_kmeans_model(fname):
+    import pickle
+    with open(fname, 'rb') as f:
+        kmeans_model = pickle.load(f)
+
+    return kmeans_model
+
+
+def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
+    segments = lime_explainer.segments
+    lime_weights = lime_explainer.local_exp[label]
+    num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
+
+    # l1 norm with filtered weights.
+    used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
+    norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
+    normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
+
+    a = 0.0
+    n = 0
+    for i, tuple_w in enumerate(normalized_weights):
+        if tuple_w[1] < 0:
+            continue
+        if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
+            continue
+
+        a += tuple_w[1]
+        if a > percentage_to_show:
+            n = i + 1
+            break
+
+    if n == 0:
+        return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
+
+    return n
+
+
+def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
+    _, nc, h, w = feature_maps.shape
+
+    cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
+    cam = cam.sum((0, 1))
+
+    if cam_min is None:
+        cam_min = np.min(cam)
+    if cam_max is None:
+        cam_max = np.max(cam)
+
+    cam = cam - cam_min
+    cam = cam / cam_max
+    cam = np.uint8(255 * cam)
+    cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
+
+    heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
+    heatmap = np.float32(heatmap)
+    cam = heatmap + np.float32(image_show)
+    cam = cam / np.max(cam)
+
+    return cam
+
+
+def avg_using_superpixels(features, segments):
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for x in np.unique(segments):
+        one_list[x] = np.mean(features[segments == x], axis=0)
+
+    return one_list
+
+
+def centroid_using_superpixels(features, segments):
+    from skimage.measure import regionprops
+    regions = regionprops(segments + 1)
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for i, r in enumerate(regions):
+        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
+    # print(one_list.shape)
+    return one_list
+
+
+def get_feature_for_kmeans(feature_map, segments):
+    from sklearn.preprocessing import normalize
+    centroid_feature = centroid_using_superpixels(feature_map, segments)
+    avg_feature = avg_using_superpixels(feature_map, segments)
+    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
+    x = normalize(x)
+    return x
+
+
+def combine_normlime_and_lime(lime_weights, g_weights):
+    pred_labels = lime_weights.keys()
+    combined_weights = {y: [] for y in pred_labels}
+
+    for y in pred_labels:
+        normlized_lime_weights_y = lime_weights[y]
+        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
+
+        normlized_g_weight_y = g_weights[y]
+        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
+
+        combined_weights[y] = [
+            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
+            for seg_k in lime_weights_dict.keys()
+        ]
+
+        combined_weights[y] = sorted(combined_weights[y],
+                                     key=lambda x: np.abs(x[1]), reverse=True)
+
+    return combined_weights
+
+
+def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
+    import matplotlib.pyplot as plt
+    if isinstance(data_, str):
+        if algorithm_name == 'cam':
+            f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
+        else:
+            f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
+        plt.savefig(
+            os.path.join(save_outdir, f_out)
+        )
+    else:
+        n = 0
+        if algorithm_name == 'cam':
+            f_out = f'cam-{n}.png'
+        else:
+            f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+        while os.path.exists(
+                os.path.join(save_outdir, f_out)
+        ):
+            n += 1
+            if algorithm_name == 'cam':
+                f_out = f'cam-{n}.png'
+            else:
+                f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+            continue
+        plt.savefig(
+            os.path.join(
+                save_outdir, f_out
+            )
+        )

+ 502 - 0
paddlex/cv/models/explanation/core/lime_base.py

@@ -0,0 +1,502 @@
+"""
+Contains abstract functionality for learning locally linear sparse model.
+"""
+from __future__ import print_function
+import numpy as np
+import scipy as sp
+import sklearn
+import sklearn.preprocessing
+from skimage.color import gray2rgb
+from sklearn.linear_model import Ridge, lars_path
+from sklearn.utils import check_random_state
+
+import copy
+from functools import partial
+from skimage.segmentation import quickshift
+from skimage.measure import regionprops
+
+
+class LimeBase(object):
+    """Class for learning a locally linear sparse model from perturbed data"""
+    def __init__(self,
+                 kernel_fn,
+                 verbose=False,
+                 random_state=None):
+        """Init function
+
+        Args:
+            kernel_fn: function that transforms an array of distances into an
+                        array of proximity values (floats).
+            verbose: if true, print local prediction values from linear model.
+            random_state: an integer or numpy.RandomState that will be used to
+                generate random numbers. If None, the random state will be
+                initialized using the internal numpy seed.
+        """
+        self.kernel_fn = kernel_fn
+        self.verbose = verbose
+        self.random_state = check_random_state(random_state)
+
+    @staticmethod
+    def generate_lars_path(weighted_data, weighted_labels):
+        """Generates the lars path for weighted data.
+
+        Args:
+            weighted_data: data that has been weighted by kernel
+            weighted_label: labels, weighted by kernel
+
+        Returns:
+            (alphas, coefs), both are arrays corresponding to the
+            regularization parameter and coefficients, respectively
+        """
+        x_vector = weighted_data
+        alphas, _, coefs = lars_path(x_vector,
+                                     weighted_labels,
+                                     method='lasso',
+                                     verbose=False)
+        return alphas, coefs
+
+    def forward_selection(self, data, labels, weights, num_features):
+        """Iteratively adds features to the model"""
+        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
+        used_features = []
+        for _ in range(min(num_features, data.shape[1])):
+            max_ = -100000000
+            best = 0
+            for feature in range(data.shape[1]):
+                if feature in used_features:
+                    continue
+                clf.fit(data[:, used_features + [feature]], labels,
+                        sample_weight=weights)
+                score = clf.score(data[:, used_features + [feature]],
+                                  labels,
+                                  sample_weight=weights)
+                if score > max_:
+                    best = feature
+                    max_ = score
+            used_features.append(best)
+        return np.array(used_features)
+
+    def feature_selection(self, data, labels, weights, num_features, method):
+        """Selects features for the model. see explain_instance_with_data to
+           understand the parameters."""
+        if method == 'none':
+            return np.array(range(data.shape[1]))
+        elif method == 'forward_selection':
+            return self.forward_selection(data, labels, weights, num_features)
+        elif method == 'highest_weights':
+            clf = Ridge(alpha=0.01, fit_intercept=True,
+                        random_state=self.random_state)
+            clf.fit(data, labels, sample_weight=weights)
+
+            coef = clf.coef_
+            if sp.sparse.issparse(data):
+                coef = sp.sparse.csr_matrix(clf.coef_)
+                weighted_data = coef.multiply(data[0])
+                # Note: most efficient to slice the data before reversing
+                sdata = len(weighted_data.data)
+                argsort_data = np.abs(weighted_data.data).argsort()
+                # Edge case where data is more sparse than requested number of feature importances
+                # In that case, we just pad with zero-valued features
+                if sdata < num_features:
+                    nnz_indexes = argsort_data[::-1]
+                    indices = weighted_data.indices[nnz_indexes]
+                    num_to_pad = num_features - sdata
+                    indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
+                    indices_set = set(indices)
+                    pad_counter = 0
+                    for i in range(data.shape[1]):
+                        if i not in indices_set:
+                            indices[pad_counter + sdata] = i
+                            pad_counter += 1
+                            if pad_counter >= num_to_pad:
+                                break
+                else:
+                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
+                    indices = weighted_data.indices[nnz_indexes]
+                return indices
+            else:
+                weighted_data = coef * data[0]
+                feature_weights = sorted(
+                    zip(range(data.shape[1]), weighted_data),
+                    key=lambda x: np.abs(x[1]),
+                    reverse=True)
+                return np.array([x[0] for x in feature_weights[:num_features]])
+        elif method == 'lasso_path':
+            weighted_data = ((data - np.average(data, axis=0, weights=weights))
+                             * np.sqrt(weights[:, np.newaxis]))
+            weighted_labels = ((labels - np.average(labels, weights=weights))
+                               * np.sqrt(weights))
+            nonzero = range(weighted_data.shape[1])
+            _, coefs = self.generate_lars_path(weighted_data,
+                                               weighted_labels)
+            for i in range(len(coefs.T) - 1, 0, -1):
+                nonzero = coefs.T[i].nonzero()[0]
+                if len(nonzero) <= num_features:
+                    break
+            used_features = nonzero
+            return used_features
+        elif method == 'auto':
+            if num_features <= 6:
+                n_method = 'forward_selection'
+            else:
+                n_method = 'highest_weights'
+            return self.feature_selection(data, labels, weights,
+                                          num_features, n_method)
+
+    def explain_instance_with_data(self,
+                                   neighborhood_data,
+                                   neighborhood_labels,
+                                   distances,
+                                   label,
+                                   num_features,
+                                   feature_selection='auto',
+                                   model_regressor=None):
+        """Takes perturbed data, labels and distances, returns explanation.
+
+        Args:
+            neighborhood_data: perturbed data, 2d array. first element is
+                               assumed to be the original data point.
+            neighborhood_labels: corresponding perturbed labels. should have as
+                                 many columns as the number of possible labels.
+            distances: distances to original data point.
+            label: label for which we want an explanation
+            num_features: maximum number of features in explanation
+            feature_selection: how to select num_features. options are:
+                'forward_selection': iteratively add features to the model.
+                    This is costly when num_features is high
+                'highest_weights': selects the features that have the highest
+                    product of absolute weight * original data point when
+                    learning with all the features
+                'lasso_path': chooses features based on the lasso
+                    regularization path
+                'none': uses all features, ignores num_features
+                'auto': uses forward_selection if num_features <= 6, and
+                    'highest_weights' otherwise.
+            model_regressor: sklearn regressor to use in explanation.
+                Defaults to Ridge regression if None. Must have
+                model_regressor.coef_ and 'sample_weight' as a parameter
+                to model_regressor.fit()
+
+        Returns:
+            (intercept, exp, score, local_pred):
+            intercept is a float.
+            exp is a sorted list of tuples, where each tuple (x,y) corresponds
+            to the feature id (x) and the local weight (y). The list is sorted
+            by decreasing absolute value of y.
+            score is the R^2 value of the returned explanation
+            local_pred is the prediction of the explanation model on the original instance
+        """
+
+        weights = self.kernel_fn(distances)
+        labels_column = neighborhood_labels[:, label]
+        used_features = self.feature_selection(neighborhood_data,
+                                               labels_column,
+                                               weights,
+                                               num_features,
+                                               feature_selection)
+        if model_regressor is None:
+            model_regressor = Ridge(alpha=1, fit_intercept=True,
+                                    random_state=self.random_state)
+        easy_model = model_regressor
+        easy_model.fit(neighborhood_data[:, used_features],
+                       labels_column, sample_weight=weights)
+        prediction_score = easy_model.score(
+            neighborhood_data[:, used_features],
+            labels_column, sample_weight=weights)
+
+        local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
+
+        if self.verbose:
+            print('Intercept', easy_model.intercept_)
+            print('Prediction_local', local_pred,)
+            print('Right:', neighborhood_labels[0, label])
+        return (easy_model.intercept_,
+                sorted(zip(used_features, easy_model.coef_),
+                       key=lambda x: np.abs(x[1]), reverse=True),
+                prediction_score, local_pred)
+
+
+class ImageExplanation(object):
+    def __init__(self, image, segments):
+        """Init function.
+
+        Args:
+            image: 3d numpy array
+            segments: 2d numpy array, with the output from skimage.segmentation
+        """
+        self.image = image
+        self.segments = segments
+        self.intercept = {}
+        self.local_exp = {}
+        self.local_pred = None
+
+    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
+                           num_features=5, min_weight=0.):
+        """Init function.
+
+        Args:
+            label: label to explain
+            positive_only: if True, only take superpixels that positively contribute to
+                the prediction of the label.
+            negative_only: if True, only take superpixels that negatively contribute to
+                the prediction of the label. If false, and so is positive_only, then both
+                negativey and positively contributions will be taken.
+                Both can't be True at the same time
+            hide_rest: if True, make the non-explanation part of the return
+                image gray
+            num_features: number of superpixels to include in explanation
+            min_weight: minimum weight of the superpixels to include in explanation
+
+        Returns:
+            (image, mask), where image is a 3d numpy array and mask is a 2d
+            numpy array that can be used with
+            skimage.segmentation.mark_boundaries
+        """
+        if label not in self.local_exp:
+            raise KeyError('Label not in explanation')
+        if positive_only & negative_only:
+            raise ValueError("Positive_only and negative_only cannot be true at the same time.")
+        segments = self.segments
+        image = self.image
+        exp = self.local_exp[label]
+        mask = np.zeros(segments.shape, segments.dtype)
+        if hide_rest:
+            temp = np.zeros(self.image.shape)
+        else:
+            temp = self.image.copy()
+        if positive_only:
+            fs = [x[0] for x in exp
+                  if x[1] > 0 and x[1] > min_weight][:num_features]
+        if negative_only:
+            fs = [x[0] for x in exp
+                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
+        if positive_only or negative_only:
+            for f in fs:
+                temp[segments == f] = image[segments == f].copy()
+                mask[segments == f] = 1
+            return temp, mask
+        else:
+            for f, w in exp[:num_features]:
+                if np.abs(w) < min_weight:
+                    continue
+                c = 0 if w < 0 else 1
+                mask[segments == f] = -1 if w < 0 else 1
+                temp[segments == f] = image[segments == f].copy()
+                temp[segments == f, c] = np.max(image)
+            return temp, mask
+
+    def get_rendered_image(self, label, min_weight=0.005):
+        """
+
+        Args:
+            label: label to explain
+            min_weight:
+
+        Returns:
+            image, is a 3d numpy array
+        """
+        if label not in self.local_exp:
+            raise KeyError('Label not in explanation')
+
+        from matplotlib import cm
+
+        segments = self.segments
+        image = self.image
+        exp = self.local_exp[label]
+        temp = np.zeros_like(image)
+
+        weight_max = abs(exp[0][1])
+        exp = [(f, w/weight_max) for f, w in exp]
+        exp = sorted(exp, key=lambda x: x[1], reverse=True)  # negatives are at last.
+
+        cmaps = cm.get_cmap('Spectral')
+        # sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
+        colors = cmaps(np.linspace(0, 1, len(exp)))
+        colors = colors[:, :3]
+
+        for i, (f, w) in enumerate(exp):
+            if np.abs(w) < min_weight:
+                continue
+            temp[segments == f] = image[segments == f].copy()
+            temp[segments == f] = colors[i] * 255
+        return temp
+
+
+class LimeImageExplainer(object):
+    """Explains predictions on Image (i.e. matrix) data.
+    For numerical features, perturb them by sampling from a Normal(0,1) and
+    doing the inverse operation of mean-centering and scaling, according to the
+    means and stds in the training data. For categorical features, perturb by
+    sampling according to the training distribution, and making a binary
+    feature that is 1 when the value is the same as the instance being
+    explained."""
+
+    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
+                 feature_selection='auto', random_state=None):
+        """Init function.
+
+        Args:
+            kernel_width: kernel width for the exponential kernel.
+            If None, defaults to sqrt(number of columns) * 0.75.
+            kernel: similarity kernel that takes euclidean distances and kernel
+                width as input and outputs weights in (0,1). If None, defaults to
+                an exponential kernel.
+            verbose: if true, print local prediction values from linear model
+            feature_selection: feature selection method. can be
+                'forward_selection', 'lasso_path', 'none' or 'auto'.
+                See function 'explain_instance_with_data' in lime_base.py for
+                details on what each of the options does.
+            random_state: an integer or numpy.RandomState that will be used to
+                generate random numbers. If None, the random state will be
+                initialized using the internal numpy seed.
+        """
+        kernel_width = float(kernel_width)
+
+        if kernel is None:
+            def kernel(d, kernel_width):
+                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
+
+        kernel_fn = partial(kernel, kernel_width=kernel_width)
+
+        self.random_state = check_random_state(random_state)
+        self.feature_selection = feature_selection
+        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
+
+    def explain_instance(self, image, classifier_fn, labels=(1,),
+                         hide_color=None,
+                         num_features=100000, num_samples=1000,
+                         batch_size=10,
+                         distance_metric='cosine',
+                         model_regressor=None
+                         ):
+        """Generates explanations for a prediction.
+
+        First, we generate neighborhood data by randomly perturbing features
+        from the instance (see __data_inverse). We then learn locally weighted
+        linear models on this neighborhood data to explain each of the classes
+        in an interpretable way (see lime_base.py).
+
+        Args:
+            image: 3 dimension RGB image. If this is only two dimensional,
+                we will assume it's a grayscale image and call gray2rgb.
+            classifier_fn: classifier prediction probability function, which
+                takes a numpy array and outputs prediction probabilities.  For
+                ScikitClassifiers , this is classifier.predict_proba.
+            labels: iterable with labels to be explained.
+            hide_color: TODO
+            num_features: maximum number of features present in explanation
+            num_samples: size of the neighborhood to learn the linear model
+            batch_size: TODO
+            distance_metric: the distance metric to use for weights.
+            model_regressor: sklearn regressor to use in explanation. Defaults
+            to Ridge regression in LimeBase. Must have model_regressor.coef_
+            and 'sample_weight' as a parameter to model_regressor.fit()
+
+        Returns:
+            An ImageExplanation object (see lime_image.py) with the corresponding
+            explanations.
+        """
+        if len(image.shape) == 2:
+            image = gray2rgb(image)
+
+        try:
+            segments = quickshift(image, sigma=1)
+        except ValueError as e:
+            raise e
+
+        self.segments = segments
+
+        fudged_image = image.copy()
+        if hide_color is None:
+            # if no hide_color, use the mean
+            for x in np.unique(segments):
+                mx = np.mean(image[segments == x], axis=0)
+                fudged_image[segments == x] = mx
+        elif hide_color == 'avg_from_neighbor':
+            from scipy.spatial.distance import cdist
+
+            n_features = np.unique(segments).shape[0]
+            regions = regionprops(segments + 1)
+            centroids = np.zeros((n_features, 2))
+            for i, x in enumerate(regions):
+                centroids[i] = np.array(x.centroid)
+
+            d = cdist(centroids, centroids, 'sqeuclidean')
+
+            for x in np.unique(segments):
+                # print(np.argmin(d[x]))
+                a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
+                mx = np.mean(np.concatenate(a), axis=0)
+                fudged_image[segments == x] = mx
+
+        else:
+            fudged_image[:] = 0
+
+        top = labels
+
+        data, labels = self.data_labels(image, fudged_image, segments,
+                                        classifier_fn, num_samples,
+                                        batch_size=batch_size)
+
+        distances = sklearn.metrics.pairwise_distances(
+            data,
+            data[0].reshape(1, -1),
+            metric=distance_metric
+        ).ravel()
+
+        ret_exp = ImageExplanation(image, segments)
+        for label in top:
+            (ret_exp.intercept[label],
+             ret_exp.local_exp[label],
+             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
+                data, labels, distances, label, num_features,
+                model_regressor=model_regressor,
+                feature_selection=self.feature_selection)
+        return ret_exp
+
+    def data_labels(self,
+                    image,
+                    fudged_image,
+                    segments,
+                    classifier_fn,
+                    num_samples,
+                    batch_size=10):
+        """Generates images and predictions in the neighborhood of this image.
+
+        Args:
+            image: 3d numpy array, the image
+            fudged_image: 3d numpy array, image to replace original image when
+                superpixel is turned off
+            segments: segmentation of the image
+            classifier_fn: function that takes a list of images and returns a
+                matrix of prediction probabilities
+            num_samples: size of the neighborhood to learn the linear model
+            batch_size: classifier_fn will be called on batches of this size.
+
+        Returns:
+            A tuple (data, labels), where:
+                data: dense num_samples * num_superpixels
+                labels: prediction probabilities matrix
+        """
+        n_features = np.unique(segments).shape[0]
+        data = self.random_state.randint(0, 2, num_samples * n_features) \
+            .reshape((num_samples, n_features))
+        labels = []
+        data[0, :] = 1
+        imgs = []
+        for row in data:
+            temp = copy.deepcopy(image)
+            zeros = np.where(row == 0)[0]
+            mask = np.zeros(segments.shape).astype(bool)
+            for z in zeros:
+                mask[segments == z] = True
+            temp[mask] = fudged_image[mask]
+            imgs.append(temp)
+            if len(imgs) == batch_size:
+                preds = classifier_fn(np.array(imgs))
+                labels.extend(preds)
+                imgs = []
+        if len(imgs) > 0:
+            preds = classifier_fn(np.array(imgs))
+            labels.extend(preds)
+        return data, np.array(labels)

+ 46 - 0
paddlex/cv/models/explanation/visualize.py

@@ -0,0 +1,46 @@
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import copy
+import os.path as osp
+import numpy as np
+from .core.explanation import Explanation
+
+
+def visualize(img_file, 
+              model, 
+              explanation_type='lime',
+              num_samples=3000, 
+              batch_size=50,
+              save_dir='./'):
+    model.arrange_transforms(
+                transforms=model.test_transforms, mode='test')
+    tmp_transforms = copy.deepcopy(model.test_transforms)
+    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
+    img = tmp_transforms(img_file)[0]
+    img = np.around(img).astype('uint8')
+    img = np.expand_dims(img, axis=0)
+    explaier = None
+    if explanation_type == 'lime':
+        explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
+    else:
+        raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
+    img_name = osp.splitext(osp.split(img_file)[-1])[0]
+    explaier.explain(img, save_dir=save_dir)
+    
+    
+def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
+    def predict_func(image):
+        image = image.astype('float32')
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = model.explanation_predict(image)
+        return out[0]
+    explaier = Explanation('lime', 
+                            predict_func,
+                            num_samples=num_samples, 
+                            batch_size=batch_size)
+    return explaier
+    

+ 4 - 1
paddlex/cv/nets/resnet.py

@@ -120,6 +120,7 @@ class ResNet(object):
         self.num_classes = num_classes
         self.lr_mult_list = lr_mult_list
         self.curr_stage = 0
+        self.features = []
 
     def _conv_offset(self,
                      input,
@@ -474,7 +475,9 @@ class ResNet(object):
                 size=self.num_classes,
                 param_attr=fluid.param_attr.ParamAttr(
                     initializer=fluid.initializer.Uniform(-stdv, stdv)))
-            return out
+            self.features.append(out)
+#             out.persistable=True
+            return out, self.features
 
         return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
                             for idx, feat in enumerate(res_endpoints)])