Эх сурвалжийг харах

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into develop_encrypt

FlyingQianMM 5 жил өмнө
parent
commit
6c806da598

+ 9 - 26
docs/tutorials/deploy/encryption.md

@@ -1,6 +1,7 @@
 # Paddle模型加密方案
 
 飞桨团队推出模型加密方案,使用业内主流的AES加密技术对最终模型进行加密。飞桨用户可以通过PaddleX导出模型后,使用该方案对模型进行加密,预测时使用解密SDK进行模型解密并完成推理,大大提升AI应用安全性和开发效率。
+** 注意:目前加密方案仅支持Linux系统**
 
 ## 1. 方案介绍
 
@@ -10,40 +11,22 @@
 
 下载并解压后,目录包含内容为:
 ```
-paddle_model_encrypt
+paddlex-encryption
 ├── include # 头文件:paddle_model_decrypt.h(解密)和paddle_model_encrypt.h(加密)
 |
 ├── lib # libpmodel-encrypt.so和libpmodel-decrypt.so动态库
 |
-└── tool # paddle_encrypt_tool
+└── tool # paddlex_encrypt_tool
 ```
 
-### 1.2 二进制工具
-
-#### 1.2.1 生成密钥
-
-产生随机密钥信息(用于AES加解密使用)(32字节key + 16字节iv, 注意这里产生的key是经过base64编码后的,这样可以扩充选取key的范围)
-
-```
-paddle_encrypt_tool    -g
-```
-#### 1.2.1 文件加密
-
-```
- paddle_encrypt_tool    -e    -key    keydata     -infile    infile    -outfile    outfile
-```
-
-#### 1.3 SDK
+### 1.2 加密PaddleX模型
 
+模型加密后,会产生随机密钥信息(用于AES加解密使用),该key值需要在模型加载时传入作为解密使用。
+> 32字节key + 16字节iv, 注意这里产生的key是经过base64编码后的,这样可以扩充选取key的范围
 ```
-// 加密API
-int paddle_encrypt_model(const char* keydata, const char* infile, const char* outfile);
-// 加载加密模型API:
-int paddle_security_load_model(
-        paddle::AnalysisConfig *config,
-        const char *key,
-        const char *model_file,
-        const char *param_file);
+./paddlex-encryption -model_dir paddlex_inference_model -save_dir paddlex_encrypted_model
 ```
+模型在加密后,会保存至指定的`-save_dir`下,同时生成密钥信息,命令输出如下图所示,密钥为`33NRtxvpDN+rkoiECm/e1Qc7sDlODdac7wp1m+3hFSU=`
+![](images/encryt.png)
 
 ## 2. PaddleX C++加密部署

BIN
docs/tutorials/deploy/images/encryt.png


+ 1 - 0
paddlex/__init__.py

@@ -29,6 +29,7 @@ from . import cls
 from . import slim
 from . import convertor
 from . import tools
+from . import interpret
 
 try:
     import pycocotools

+ 2 - 7
paddlex/cv/models/classifier.py

@@ -27,7 +27,6 @@ from .base import BaseAPI
 
 class BaseClassifier(BaseAPI):
     """构建分类器,并实现其训练、评估、预测和模型导出。
-
     Args:
         model_name (str): 分类器的模型名字,取值范围为['ResNet18',
                           'ResNet34', 'ResNet50', 'ResNet101',
@@ -65,6 +64,8 @@ class BaseClassifier(BaseAPI):
         softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
         inputs = OrderedDict([('image', image)])
         outputs = OrderedDict([('predict', softmax_out)])
+        if mode == 'test':
+            self.interpretation_feats = OrderedDict([('logits', net_out)])
         if mode != 'test':
             cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
             avg_cost = fluid.layers.mean(cost)
@@ -115,7 +116,6 @@ class BaseClassifier(BaseAPI):
               early_stop_patience=5,
               resume_checkpoint=None):
         """训练。
-
         Args:
             num_epochs (int): 训练迭代轮数。
             train_dataset (paddlex.datasets): 训练数据读取器。
@@ -139,7 +139,6 @@ class BaseClassifier(BaseAPI):
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
             resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
-
         Raises:
             ValueError: 模型从inference model进行加载。
         """
@@ -183,13 +182,11 @@ class BaseClassifier(BaseAPI):
                  epoch_id=None,
                  return_details=False):
         """评估。
-
         Args:
             eval_dataset (paddlex.datasets): 验证数据读取器。
             batch_size (int): 验证数据批大小。默认为1。
             epoch_id (int): 当前评估模型所在的训练轮数。
             return_details (bool): 是否返回详细信息。
-
         Returns:
           dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
               分别表示最大值的accuracy、前5个最大值的accuracy。
@@ -248,12 +245,10 @@ class BaseClassifier(BaseAPI):
 
     def predict(self, img_file, transforms=None, topk=1):
         """预测。
-
         Args:
             img_file (str): 预测图像路径。
             transforms (paddlex.cls.transforms): 数据预处理操作。
             topk (int): 预测时前k个最大值。
-
         Returns:
             list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
             分别对应预测类别id、预测类别标签、预测得分。

+ 1 - 3
paddlex/cv/nets/darknet.py

@@ -136,10 +136,8 @@ class DarkNet(object):
     def __call__(self, input):
         """
         Get the backbone of DarkNet, that is output for the 5 stages.
-
         Args:
             input (Variable): input variable.
-
         Returns:
             The last variables of each stage.
         """
@@ -184,4 +182,4 @@ class DarkNet(object):
                 bias_attr=ParamAttr(name='fc_offset'))
             return out
 
-        return blocks
+        return blocks

+ 1 - 1
paddlex/cv/nets/densenet.py

@@ -173,4 +173,4 @@ class DenseNet(object):
             bn_ac_conv = fluid.layers.dropout(
                 x=bn_ac_conv, dropout_prob=dropout)
         bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1)
-        return bn_ac_conv
+        return bn_ac_conv

+ 1 - 2
paddlex/cv/nets/mobilenet_v1.py

@@ -24,7 +24,6 @@ from paddle.fluid.regularizer import L2Decay
 class MobileNetV1(object):
     """
     MobileNet v1, see https://arxiv.org/abs/1704.04861
-
     Args:
         norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
         norm_decay (float): weight decay for normalization layer weights
@@ -214,4 +213,4 @@ class MobileNetV1(object):
         module17 = self._extra_block(module16, num_filters[3][0],
                                      num_filters[3][1], 1, 2,
                                      self.prefix_name + "conv7_4")
-        return module11, module13, module14, module15, module16, module17
+        return module11, module13, module14, module15, module16, module17

+ 1 - 1
paddlex/cv/nets/mobilenet_v2.py

@@ -239,4 +239,4 @@ class MobileNetV2:
                 padding=1,
                 expansion_factor=t,
                 name=name + '_' + str(i + 1))
-        return last_residual_block, depthwise_output
+        return last_residual_block, depthwise_output

+ 1 - 1
paddlex/cv/nets/shufflenet_v2.py

@@ -269,4 +269,4 @@ class ShuffleNetV2():
                 name='stage_' + name + '_conv3')
             out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1)
 
-        return self.channel_shuffle(out, 2)
+        return self.channel_shuffle(out, 2)

+ 1 - 1
paddlex/cv/nets/xception.py

@@ -329,4 +329,4 @@ def xception_41(num_classes=None):
 
 def xception_71(num_classes=None):
     model = Xception(num_classes, 71)
-    return model
+    return model

+ 18 - 0
paddlex/interpret/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from . import visualize
+
+visualize = visualize.visualize

+ 22 - 0
paddlex/interpret/as_data_reader/data_path_utils.py

@@ -0,0 +1,22 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+
+def _find_classes(dir):
+    # Faster and available in Python 3.5 and above
+    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+    classes.sort()
+    class_to_idx = {classes[i]: i for i in range(len(classes))}
+    return classes, class_to_idx

+ 225 - 0
paddlex/interpret/as_data_reader/readers.py

@@ -0,0 +1,225 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import sys
+import cv2
+import numpy as np
+import six
+import glob
+from .data_path_utils import _find_classes
+from PIL import Image
+
+
+def resize_short(img, target_size, interpolation=None):
+    """resize image
+
+    Args:
+        img: image data
+        target_size: resize short target size
+        interpolation: interpolation mode
+
+    Returns:
+        resized image data
+    """
+    percent = float(target_size) / min(img.shape[0], img.shape[1])
+    resized_width = int(round(img.shape[1] * percent))
+    resized_height = int(round(img.shape[0] * percent))
+    if interpolation:
+        resized = cv2.resize(
+            img, (resized_width, resized_height), interpolation=interpolation)
+    else:
+        resized = cv2.resize(img, (resized_width, resized_height))
+    return resized
+
+
+def crop_image(img, target_size, center=True):
+    """crop image
+
+    Args:
+        img: images data
+        target_size: crop target size
+        center: crop mode
+
+    Returns:
+        img: cropped image data
+    """
+    height, width = img.shape[:2]
+    size = target_size
+    if center:
+        w_start = (width - size) // 2
+        h_start = (height - size) // 2
+    else:
+        w_start = np.random.randint(0, width - size + 1)
+        h_start = np.random.randint(0, height - size + 1)
+    w_end = w_start + size
+    h_end = h_start + size
+    img = img[h_start:h_end, w_start:w_end, :]
+    return img
+
+
+def preprocess_image(img, random_mirror=False):
+    """
+    centered, scaled by 1/255.
+    :param img: np.array: shape: [ns, h, w, 3], color order: rgb.
+    :return: np.array: shape: [ns, h, w, 3]
+    """
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+
+    # transpose to [ns, 3, h, w]
+    img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
+
+    img_mean = np.array(mean).reshape((3, 1, 1))
+    img_std = np.array(std).reshape((3, 1, 1))
+    img -= img_mean
+    img /= img_std
+
+    if random_mirror:
+        mirror = int(np.random.uniform(0, 2))
+        if mirror == 1:
+            img = img[:, :, ::-1, :]
+
+    return img
+
+
+def read_image(img_path, target_size=256, crop_size=224):
+    """
+    resize_short to 256, then center crop to 224.
+    :param img_path: one image path
+    :return: np.array: shape: [1, h, w, 3], color order: rgb.
+    """
+
+    if isinstance(img_path, str):
+        with open(img_path, 'rb') as f:
+            img = Image.open(f)
+            img = img.convert('RGB')
+            img = np.array(img)
+            # img = cv2.imread(img_path)
+
+            img = resize_short(img, target_size, interpolation=None)
+            img = crop_image(img, target_size=crop_size, center=True)
+            # img = img[:, :, ::-1]
+            img = np.expand_dims(img, axis=0)
+            return img
+    elif isinstance(img_path, np.ndarray):
+        assert len(img_path.shape) == 4
+        return img_path
+    else:
+        ValueError(f"Not recognized data type {type(img_path)}.")
+
+
+class ReaderConfig(object):
+    """
+    A generic data loader where the images are arranged in this way:
+
+        root/train/dog/xxy.jpg
+        root/train/dog/xxz.jpg
+        ...
+        root/train/cat/nsdf3.jpg
+        root/train/cat/asd932_.jpg
+        ...
+
+        root/test/dog/xxx.jpg
+        ...
+        root/test/cat/123.jpg
+        ...
+
+    """
+    def __init__(self, dataset_dir, is_test):
+        image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test)
+        random_per = np.random.permutation(range(len(image_paths)))
+        self.image_paths = image_paths[random_per]
+        self.labels = labels[random_per]
+        self.is_test = is_test
+
+    def get_reader(self):
+        def reader():
+            IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+            target_size = 256
+            crop_size = 224
+
+            for i, img_path in enumerate(self.image_paths):
+                if not img_path.lower().endswith(IMG_EXTENSIONS):
+                    continue
+
+                img = cv2.imread(img_path)
+                if img is None:
+                    print(img_path)
+                    continue
+                img = resize_short(img, target_size, interpolation=None)
+                img = crop_image(img, crop_size, center=self.is_test)
+                img = img[:, :, ::-1]
+                img = np.expand_dims(img, axis=0)
+
+                img = preprocess_image(img, not self.is_test)
+
+                yield img, self.labels[i]
+
+        return reader
+
+    def get_dataset_info(self, dataset_dir, is_test=False):
+        IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+        # read
+        if is_test:
+            datasubset_dir = os.path.join(dataset_dir, 'test')
+        else:
+            datasubset_dir = os.path.join(dataset_dir, 'train')
+
+        class_names, class_to_idx = _find_classes(datasubset_dir)
+        # num_classes = len(class_names)
+        image_paths = []
+        labels = []
+        for class_name in class_names:
+            classes_dir = os.path.join(datasubset_dir, class_name)
+            for img_path in glob.glob(os.path.join(classes_dir, '*')):
+                if not img_path.lower().endswith(IMG_EXTENSIONS):
+                    continue
+
+                image_paths.append(img_path)
+                labels.append(class_to_idx[class_name])
+
+        image_paths = np.array(image_paths)
+        labels = np.array(labels)
+        return image_paths, labels, len(class_names)
+
+
+def create_reader(list_image_path, list_label=None, is_test=False):
+    def reader():
+        IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+        target_size = 256
+        crop_size = 224
+
+        for i, img_path in enumerate(list_image_path):
+            if not img_path.lower().endswith(IMG_EXTENSIONS):
+                continue
+
+            img = cv2.imread(img_path)
+            if img is None:
+                print(img_path)
+                continue
+
+            img = resize_short(img, target_size, interpolation=None)
+            img = crop_image(img, crop_size, center=is_test)
+            img = img[:, :, ::-1]
+            img_show = np.expand_dims(img, axis=0)
+
+            img = preprocess_image(img_show, not is_test)
+
+            label = 0 if list_label is None else list_label[i]
+
+            yield img_show, img, label
+
+    return reader

+ 113 - 0
paddlex/interpret/core/_session_preparation.py

@@ -0,0 +1,113 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import paddle.fluid as fluid
+import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from ..as_data_reader.readers import preprocess_image
+
+root_path = os.environ['HOME']
+root_path = os.path.join(root_path, '.paddlex')
+h_pre_models = os.path.join(root_path, "pre_models")
+h_pre_models_kmeans = os.path.join(h_pre_models, "kmeans_model.pkl")
+
+
+def paddle_get_fc_weights(var_name="fc_0.w_0"):
+    fc_weights = fluid.global_scope().find_var(var_name).get_tensor()
+    return np.array(fc_weights)
+
+
+def paddle_resize(extracted_features, outsize):
+    resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
+    return resized_features
+
+
+def compute_features_for_kmeans(data_content):
+    def conv_bn_layer(input,
+                      num_filters,
+                      filter_size,
+                      stride=1,
+                      groups=1,
+                      act=None,
+                      name=None,
+                      is_test=True,
+                      global_name=''):
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            act=None,
+            param_attr=ParamAttr(name=global_name + name + "_weights"),
+            bias_attr=False,
+            name=global_name + name + '.conv2d.output.1')
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        return fluid.layers.batch_norm(
+            input=conv,
+            act=act,
+            name=global_name + bn_name + '.output.1',
+            param_attr=ParamAttr(global_name + bn_name + '_scale'),
+            bias_attr=ParamAttr(global_name + bn_name + '_offset'),
+            moving_mean_name=global_name + bn_name + '_mean',
+            moving_variance_name=global_name + bn_name + '_variance',
+            use_global_stats=is_test
+        )
+
+    startup_prog = fluid.default_startup_program().clone(for_test=True)
+    prog = fluid.Program()
+    with fluid.program_guard(prog, startup_prog):
+        with fluid.unique_name.guard():
+            image_op = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
+
+            conv = conv_bn_layer(
+                input=image_op,
+                num_filters=32,
+                filter_size=3,
+                stride=2,
+                act='relu',
+                name='conv1_1')
+            conv = conv_bn_layer(
+                input=conv,
+                num_filters=32,
+                filter_size=3,
+                stride=1,
+                act='relu',
+                name='conv1_2')
+            conv = conv_bn_layer(
+                input=conv,
+                num_filters=64,
+                filter_size=3,
+                stride=1,
+                act='relu',
+                name='conv1_3')
+            extracted_features = conv
+            resized_features = fluid.layers.resize_bilinear(extracted_features, image_op.shape[2:])
+
+    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+    place = fluid.CUDAPlace(gpu_id)
+    # place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(startup_prog)
+    fluid.io.load_persistables(exe, h_pre_models, prog)
+
+    images = preprocess_image(data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
+    result = exe.run(prog, fetch_list=[resized_features], feed={'image': images})
+
+    return result[0][0]

+ 51 - 0
paddlex/interpret/core/interpretation.py

@@ -0,0 +1,51 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from .interpretation_algorithms import CAM, LIME, NormLIME
+from .normlime_base import precompute_normlime_weights
+
+
+class Interpretation(object):
+    """
+    Base class for all interpretation algorithms.
+    """
+    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
+        supported_algorithms = {
+            'cam': CAM,
+            'lime': LIME,
+            'normlime': NormLIME
+        }
+
+        self.algorithm_name = interpretation_algorithm_name.lower()
+        assert self.algorithm_name in supported_algorithms.keys()
+        self.predict_fn = predict_fn
+
+        # initialization for the interpretation algorithm.
+        self.algorithm = supported_algorithms[self.algorithm_name](
+            self.predict_fn, label_names, **kwargs
+        )
+
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+        """
+
+        Args:
+            data_: data_ can be a path or numpy.ndarray.
+            visualization: whether to show using matplotlib.
+            save_to_disk: whether to save the figure in local disk.
+            save_dir: dir to save figure if save_to_disk is True.
+
+        Returns:
+
+        """
+        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)

+ 444 - 0
paddlex/interpret/core/interpretation_algorithms.py

@@ -0,0 +1,444 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import numpy as np
+import time
+
+from . import lime_base
+from ..as_data_reader.readers import read_image
+from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, h_pre_models_kmeans
+from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
+
+import cv2
+
+
+class CAM(object):
+    def __init__(self, predict_fn, label_names):
+        """
+
+        Args:
+            predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
+                        output: [
+                        logits [N, num_classes],
+                        feature map before global average pooling [N, num_channels, h_, w_]
+                        ]
+
+        """
+        self.predict_fn = predict_fn
+        self.label_names = label_names
+
+    def preparation_cam(self, data_):
+        image_show = read_image(data_)
+        result = self.predict_fn(image_show)
+
+        logit = result[0][0]
+        if abs(np.sum(logit) - 1.0) > 1e-4:
+            # softmax
+            logit = logit - np.max(logit)
+            exp_result = np.exp(logit)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = logit
+
+        # only interpret top 1
+        pred_label = np.argsort(probability)
+        pred_label = pred_label[-1:]
+
+        self.predicted_label = pred_label[0]
+        self.predicted_probability = probability[pred_label[0]]
+        self.image = image_show[0]
+        self.labels = pred_label
+
+        fc_weights = paddle_get_fc_weights()
+        feature_maps = result[1]
+        
+        l = pred_label[0]
+        ln = l
+        if self.label_names is not None:
+            ln = self.label_names[l]
+
+        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
+        return feature_maps, fc_weights
+
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        feature_maps, fc_weights = self.preparation_cam(data_)
+        cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+            ln = l 
+            if self.label_names is not None:
+                ln = self.label_names[l]
+
+            psize = 5
+            nrows = 1
+            ncols = 2
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(cam)
+            axes[1].set_title("CAM")
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'cam')
+
+        if visualization:
+            plt.show()
+
+        return
+
+
+class LIME(object):
+    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
+        """
+        LIME wrapper. See lime_base.py for the detailed LIME implementation.
+        Args:
+            predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
+            num_samples: the number of samples that LIME takes for fitting.
+            batch_size: batch size for model inference each time.
+        """
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        self.predict_fn = predict_fn
+        self.labels = None
+        self.image = None
+        self.lime_interpreter = None
+        self.label_names = label_names
+
+    def preparation_lime(self, data_):
+        image_show = read_image(data_)
+        result = self.predict_fn(image_show)
+
+        result = result[0]  # only one image here.
+
+        if abs(np.sum(result) - 1.0) > 1e-4:
+            # softmax
+            result = result - np.max(result)
+            exp_result = np.exp(result)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = result
+
+        # only interpret top 1
+        pred_label = np.argsort(probability)
+        pred_label = pred_label[-1:]
+
+        self.predicted_label = pred_label[0]
+        self.predicted_probability = probability[pred_label[0]]
+        self.image = image_show[0]
+        self.labels = pred_label
+        
+        l = pred_label[0]
+        ln = l
+        if self.label_names is not None:
+            ln = self.label_names[l]
+            
+        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
+
+        end = time.time()
+        algo = lime_base.LimeImageInterpreter()
+        interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
+                                              num_samples=self.num_samples, batch_size=self.batch_size)
+        self.lime_interpreter = interpreter
+        print('lime time: ', time.time() - end, 's.')
+
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.lime_interpreter is None:
+            self.preparation_lime(data_)
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+            ln = l 
+            if self.label_names is not None:
+                ln = self.label_names[l]
+
+            psize = 5
+            nrows = 2
+            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
+            ncols = len(weights_choices)
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
+            axes[1].set_title("superpixel segmentation")
+
+            # LIME visualization
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
+                temp, mask = self.lime_interpreter.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'lime', self.num_samples)
+
+        if visualization:
+            plt.show()
+
+        return
+
+
+class NormLIME(object):
+    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
+                 kmeans_model_for_normlime=None, normlime_weights=None):
+        if kmeans_model_for_normlime is None:
+            try:
+                self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+            except:
+                raise ValueError("NormLIME needs the KMeans model, where we provided a default one in "
+                                 "pre_models/kmeans_model.pkl.")
+        else:
+            print("Warning: It is *strongly* suggested to use the default KMeans model in pre_models/kmeans_model.pkl. "
+                  "Use another one will change the final result.")
+            self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
+
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        try:
+            self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
+        except:
+            self.normlime_weights = None
+            print("Warning: not find the correct precomputed Normlime result.")
+
+        self.predict_fn = predict_fn
+
+        self.labels = None
+        self.image = None
+        self.label_names = label_names
+
+    def predict_cluster_labels(self, feature_map, segments):
+        return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
+
+    def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
+        # global weights
+        g_weights = {y: [] for y in pred_labels}
+        for y in pred_labels:
+            cluster_weights_y = self.normlime_weights.get(y, {})
+            g_weights[y] = [
+                (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
+            ]
+
+            g_weights[y] = sorted(g_weights[y],
+                                  key=lambda x: np.abs(x[1]), reverse=True)
+
+        return g_weights
+
+    def preparation_normlime(self, data_):
+        self._lime = LIME(
+            self.predict_fn,
+            self.label_names,
+            self.num_samples,
+            self.batch_size
+        )
+        self._lime.preparation_lime(data_)
+
+        image_show = read_image(data_)
+
+        self.predicted_label = self._lime.predicted_label
+        self.predicted_probability = self._lime.predicted_probability
+        self.image = image_show[0]
+        self.labels = self._lime.labels
+        # print(f'predicted result: {self.predicted_label} with probability {self.predicted_probability: .3f}')
+        print('performing NormLIME operations ...')
+
+        cluster_labels = self.predict_cluster_labels(
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
+        )
+
+        g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
+
+        return g_weights
+
+    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.normlime_weights is None:
+            raise ValueError("Not find the correct precomputed NormLIME result. \n"
+                             "\t Try to call compute_normlime_weights() first or load the correct path.")
+
+        g_weights = self.preparation_normlime(data_)
+        lime_weights = self._lime.lime_interpreter.local_weights
+
+        if visualization or save_to_disk:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+            ln = l
+            if self.label_names is not None:
+                ln = self.label_names[l]
+
+            psize = 5
+            nrows = 4
+            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
+            nums_to_show = []
+            ncols = len(weights_choices)
+
+            plt.close()
+            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+
+            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
+            axes[1].set_title("superpixel segmentation")
+
+            # LIME visualization
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
+                nums_to_show.append(num_to_show)
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
+
+            # NormLIME visualization
+            self._lime.lime_interpreter.local_weights = g_weights
+            for i, num_to_show in enumerate(nums_to_show):
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 2 + i].set_title(f"NormLIME: first {num_to_show} superpixels")
+
+            # NormLIME*LIME visualization
+            combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
+            self._lime.lime_interpreter.local_weights = combined_weights
+            for i, num_to_show in enumerate(nums_to_show):
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l, positive_only=False, hide_rest=False, num_features=num_to_show
+                )
+                axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
+
+            self._lime.lime_interpreter.local_weights = lime_weights
+
+        if save_to_disk and save_outdir is not None:
+            os.makedirs(save_outdir, exist_ok=True)
+            save_fig(data_, save_outdir, 'normlime', self.num_samples)
+
+        if visualization:
+            plt.show()
+
+
+def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
+    segments = lime_interpreter.segments
+    lime_weights = lime_interpreter.local_weights[label]
+    num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
+
+    # l1 norm with filtered weights.
+    used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
+    norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
+    normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
+
+    a = 0.0
+    n = 0
+    for i, tuple_w in enumerate(normalized_weights):
+        if tuple_w[1] < 0:
+            continue
+        if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
+            continue
+
+        a += tuple_w[1]
+        if a > percentage_to_show:
+            n = i + 1
+            break
+
+    if percentage_to_show <= 0.0:
+        return 5
+
+    if n == 0:
+        return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)
+
+    return n
+
+
+def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
+    _, nc, h, w = feature_maps.shape
+
+    cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
+    cam = cam.sum((0, 1))
+
+    if cam_min is None:
+        cam_min = np.min(cam)
+    if cam_max is None:
+        cam_max = np.max(cam)
+
+    cam = cam - cam_min
+    cam = cam / cam_max
+    cam = np.uint8(255 * cam)
+    cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
+
+    heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
+    heatmap = np.float32(heatmap)
+    cam = heatmap + np.float32(image_show)
+    cam = cam / np.max(cam)
+
+    return cam
+
+
+def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
+    import matplotlib.pyplot as plt
+    if isinstance(data_, str):
+        if algorithm_name == 'cam':
+            f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
+        else:
+            f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
+        plt.savefig(
+            os.path.join(save_outdir, f_out)
+        )
+    else:
+        n = 0
+        if algorithm_name == 'cam':
+            f_out = f'cam-{n}.png'
+        else:
+            f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+        while os.path.exists(
+                os.path.join(save_outdir, f_out)
+        ):
+            n += 1
+            if algorithm_name == 'cam':
+                f_out = f'cam-{n}.png'
+            else:
+                f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+            continue
+        plt.savefig(
+            os.path.join(
+                save_outdir, f_out
+            )
+        )

+ 527 - 0
paddlex/interpret/core/lime_base.py

@@ -0,0 +1,527 @@
+"""
+Copyright (c) 2016, Marco Tulio Correia Ribeiro
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+"""
+The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
+"""
+
+
+import numpy as np
+import scipy as sp
+import sklearn
+import sklearn.preprocessing
+from skimage.color import gray2rgb
+from sklearn.linear_model import Ridge, lars_path
+from sklearn.utils import check_random_state
+
+import copy
+from functools import partial
+from skimage.segmentation import quickshift
+from skimage.measure import regionprops
+
+
+class LimeBase(object):
+    """Class for learning a locally linear sparse model from perturbed data"""
+    def __init__(self,
+                 kernel_fn,
+                 verbose=False,
+                 random_state=None):
+        """Init function
+
+        Args:
+            kernel_fn: function that transforms an array of distances into an
+                        array of proximity values (floats).
+            verbose: if true, print local prediction values from linear model.
+            random_state: an integer or numpy.RandomState that will be used to
+                generate random numbers. If None, the random state will be
+                initialized using the internal numpy seed.
+        """
+        self.kernel_fn = kernel_fn
+        self.verbose = verbose
+        self.random_state = check_random_state(random_state)
+
+    @staticmethod
+    def generate_lars_path(weighted_data, weighted_labels):
+        """Generates the lars path for weighted data.
+
+        Args:
+            weighted_data: data that has been weighted by kernel
+            weighted_label: labels, weighted by kernel
+
+        Returns:
+            (alphas, coefs), both are arrays corresponding to the
+            regularization parameter and coefficients, respectively
+        """
+        x_vector = weighted_data
+        alphas, _, coefs = lars_path(x_vector,
+                                     weighted_labels,
+                                     method='lasso',
+                                     verbose=False)
+        return alphas, coefs
+
+    def forward_selection(self, data, labels, weights, num_features):
+        """Iteratively adds features to the model"""
+        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
+        used_features = []
+        for _ in range(min(num_features, data.shape[1])):
+            max_ = -100000000
+            best = 0
+            for feature in range(data.shape[1]):
+                if feature in used_features:
+                    continue
+                clf.fit(data[:, used_features + [feature]], labels,
+                        sample_weight=weights)
+                score = clf.score(data[:, used_features + [feature]],
+                                  labels,
+                                  sample_weight=weights)
+                if score > max_:
+                    best = feature
+                    max_ = score
+            used_features.append(best)
+        return np.array(used_features)
+
+    def feature_selection(self, data, labels, weights, num_features, method):
+        """Selects features for the model. see interpret_instance_with_data to
+           understand the parameters."""
+        if method == 'none':
+            return np.array(range(data.shape[1]))
+        elif method == 'forward_selection':
+            return self.forward_selection(data, labels, weights, num_features)
+        elif method == 'highest_weights':
+            clf = Ridge(alpha=0.01, fit_intercept=True,
+                        random_state=self.random_state)
+            clf.fit(data, labels, sample_weight=weights)
+
+            coef = clf.coef_
+            if sp.sparse.issparse(data):
+                coef = sp.sparse.csr_matrix(clf.coef_)
+                weighted_data = coef.multiply(data[0])
+                # Note: most efficient to slice the data before reversing
+                sdata = len(weighted_data.data)
+                argsort_data = np.abs(weighted_data.data).argsort()
+                # Edge case where data is more sparse than requested number of feature importances
+                # In that case, we just pad with zero-valued features
+                if sdata < num_features:
+                    nnz_indexes = argsort_data[::-1]
+                    indices = weighted_data.indices[nnz_indexes]
+                    num_to_pad = num_features - sdata
+                    indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
+                    indices_set = set(indices)
+                    pad_counter = 0
+                    for i in range(data.shape[1]):
+                        if i not in indices_set:
+                            indices[pad_counter + sdata] = i
+                            pad_counter += 1
+                            if pad_counter >= num_to_pad:
+                                break
+                else:
+                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
+                    indices = weighted_data.indices[nnz_indexes]
+                return indices
+            else:
+                weighted_data = coef * data[0]
+                feature_weights = sorted(
+                    zip(range(data.shape[1]), weighted_data),
+                    key=lambda x: np.abs(x[1]),
+                    reverse=True)
+                return np.array([x[0] for x in feature_weights[:num_features]])
+        elif method == 'lasso_path':
+            weighted_data = ((data - np.average(data, axis=0, weights=weights))
+                             * np.sqrt(weights[:, np.newaxis]))
+            weighted_labels = ((labels - np.average(labels, weights=weights))
+                               * np.sqrt(weights))
+            nonzero = range(weighted_data.shape[1])
+            _, coefs = self.generate_lars_path(weighted_data,
+                                               weighted_labels)
+            for i in range(len(coefs.T) - 1, 0, -1):
+                nonzero = coefs.T[i].nonzero()[0]
+                if len(nonzero) <= num_features:
+                    break
+            used_features = nonzero
+            return used_features
+        elif method == 'auto':
+            if num_features <= 6:
+                n_method = 'forward_selection'
+            else:
+                n_method = 'highest_weights'
+            return self.feature_selection(data, labels, weights,
+                                          num_features, n_method)
+
+    def interpret_instance_with_data(self,
+                                     neighborhood_data,
+                                     neighborhood_labels,
+                                     distances,
+                                     label,
+                                     num_features,
+                                     feature_selection='auto',
+                                     model_regressor=None):
+        """Takes perturbed data, labels and distances, returns interpretation.
+
+        Args:
+            neighborhood_data: perturbed data, 2d array. first element is
+                               assumed to be the original data point.
+            neighborhood_labels: corresponding perturbed labels. should have as
+                                 many columns as the number of possible labels.
+            distances: distances to original data point.
+            label: label for which we want an interpretation
+            num_features: maximum number of features in interpretation
+            feature_selection: how to select num_features. options are:
+                'forward_selection': iteratively add features to the model.
+                    This is costly when num_features is high
+                'highest_weights': selects the features that have the highest
+                    product of absolute weight * original data point when
+                    learning with all the features
+                'lasso_path': chooses features based on the lasso
+                    regularization path
+                'none': uses all features, ignores num_features
+                'auto': uses forward_selection if num_features <= 6, and
+                    'highest_weights' otherwise.
+            model_regressor: sklearn regressor to use in interpretation.
+                Defaults to Ridge regression if None. Must have
+                model_regressor.coef_ and 'sample_weight' as a parameter
+                to model_regressor.fit()
+
+        Returns:
+            (intercept, exp, score, local_pred):
+            intercept is a float.
+            exp is a sorted list of tuples, where each tuple (x,y) corresponds
+            to the feature id (x) and the local weight (y). The list is sorted
+            by decreasing absolute value of y.
+            score is the R^2 value of the returned interpretation
+            local_pred is the prediction of the interpretation model on the original instance
+        """
+
+        weights = self.kernel_fn(distances)
+        labels_column = neighborhood_labels[:, label]
+        used_features = self.feature_selection(neighborhood_data,
+                                               labels_column,
+                                               weights,
+                                               num_features,
+                                               feature_selection)
+        if model_regressor is None:
+            model_regressor = Ridge(alpha=1, fit_intercept=True,
+                                    random_state=self.random_state)
+        easy_model = model_regressor
+        easy_model.fit(neighborhood_data[:, used_features],
+                       labels_column, sample_weight=weights)
+        prediction_score = easy_model.score(
+            neighborhood_data[:, used_features],
+            labels_column, sample_weight=weights)
+
+        local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
+
+        if self.verbose:
+            print('Intercept', easy_model.intercept_)
+            print('Prediction_local', local_pred,)
+            print('Right:', neighborhood_labels[0, label])
+        return (easy_model.intercept_,
+                sorted(zip(used_features, easy_model.coef_),
+                       key=lambda x: np.abs(x[1]), reverse=True),
+                prediction_score, local_pred)
+
+
+class ImageInterpretation(object):
+    def __init__(self, image, segments):
+        """Init function.
+
+        Args:
+            image: 3d numpy array
+            segments: 2d numpy array, with the output from skimage.segmentation
+        """
+        self.image = image
+        self.segments = segments
+        self.intercept = {}
+        self.local_weights = {}
+        self.local_pred = None
+
+    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
+                           num_features=5, min_weight=0.):
+        """Init function.
+
+        Args:
+            label: label to interpret
+            positive_only: if True, only take superpixels that positively contribute to
+                the prediction of the label.
+            negative_only: if True, only take superpixels that negatively contribute to
+                the prediction of the label. If false, and so is positive_only, then both
+                negativey and positively contributions will be taken.
+                Both can't be True at the same time
+            hide_rest: if True, make the non-interpretation part of the return
+                image gray
+            num_features: number of superpixels to include in interpretation
+            min_weight: minimum weight of the superpixels to include in interpretation
+
+        Returns:
+            (image, mask), where image is a 3d numpy array and mask is a 2d
+            numpy array that can be used with
+            skimage.segmentation.mark_boundaries
+        """
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
+        if positive_only & negative_only:
+            raise ValueError("Positive_only and negative_only cannot be true at the same time.")
+        segments = self.segments
+        image = self.image
+        local_weights_label = self.local_weights[label]
+        mask = np.zeros(segments.shape, segments.dtype)
+        if hide_rest:
+            temp = np.zeros(self.image.shape)
+        else:
+            temp = self.image.copy()
+        if positive_only:
+            fs = [x[0] for x in local_weights_label
+                  if x[1] > 0 and x[1] > min_weight][:num_features]
+        if negative_only:
+            fs = [x[0] for x in local_weights_label
+                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
+        if positive_only or negative_only:
+            for f in fs:
+                temp[segments == f] = image[segments == f].copy()
+                mask[segments == f] = 1
+            return temp, mask
+        else:
+            for f, w in local_weights_label[:num_features]:
+                if np.abs(w) < min_weight:
+                    continue
+                c = 0 if w < 0 else 1
+                mask[segments == f] = -1 if w < 0 else 1
+                temp[segments == f] = image[segments == f].copy()
+                temp[segments == f, c] = np.max(image)
+            return temp, mask
+
+    def get_rendered_image(self, label, min_weight=0.005):
+        """
+
+        Args:
+            label: label to interpret
+            min_weight:
+
+        Returns:
+            image, is a 3d numpy array
+        """
+        if label not in self.local_weights:
+            raise KeyError('Label not in interpretation')
+
+        from matplotlib import cm
+
+        segments = self.segments
+        image = self.image
+        local_weights_label = self.local_weights[label]
+        temp = np.zeros_like(image)
+
+        weight_max = abs(local_weights_label[0][1])
+        local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
+        local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True)  # negatives are at last.
+
+        cmaps = cm.get_cmap('Spectral')
+        colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
+        colors = colors[:, :3]
+
+        for i, (f, w) in enumerate(local_weights_label):
+            if np.abs(w) < min_weight:
+                continue
+            temp[segments == f] = image[segments == f].copy()
+            temp[segments == f] = colors[i] * 255
+        return temp
+
+
+class LimeImageInterpreter(object):
+    """Interpres predictions on Image (i.e. matrix) data.
+    For numerical features, perturb them by sampling from a Normal(0,1) and
+    doing the inverse operation of mean-centering and scaling, according to the
+    means and stds in the training data. For categorical features, perturb by
+    sampling according to the training distribution, and making a binary
+    feature that is 1 when the value is the same as the instance being
+    interpreted."""
+
+    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
+                 feature_selection='auto', random_state=None):
+        """Init function.
+
+        Args:
+            kernel_width: kernel width for the exponential kernel.
+            If None, defaults to sqrt(number of columns) * 0.75.
+            kernel: similarity kernel that takes euclidean distances and kernel
+                width as input and outputs weights in (0,1). If None, defaults to
+                an exponential kernel.
+            verbose: if true, print local prediction values from linear model
+            feature_selection: feature selection method. can be
+                'forward_selection', 'lasso_path', 'none' or 'auto'.
+                See function 'einterpret_instance_with_data' in lime_base.py for
+                details on what each of the options does.
+            random_state: an integer or numpy.RandomState that will be used to
+                generate random numbers. If None, the random state will be
+                initialized using the internal numpy seed.
+        """
+        kernel_width = float(kernel_width)
+
+        if kernel is None:
+            def kernel(d, kernel_width):
+                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
+
+        kernel_fn = partial(kernel, kernel_width=kernel_width)
+
+        self.random_state = check_random_state(random_state)
+        self.feature_selection = feature_selection
+        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
+
+    def interpret_instance(self, image, classifier_fn, labels=(1,),
+                           hide_color=None,
+                           num_features=100000, num_samples=1000,
+                           batch_size=10,
+                           distance_metric='cosine',
+                           model_regressor=None
+                           ):
+        """Generates interpretations for a prediction.
+
+        First, we generate neighborhood data by randomly perturbing features
+        from the instance (see __data_inverse). We then learn locally weighted
+        linear models on this neighborhood data to interpret each of the classes
+        in an interpretable way (see lime_base.py).
+
+        Args:
+            image: 3 dimension RGB image. If this is only two dimensional,
+                we will assume it's a grayscale image and call gray2rgb.
+            classifier_fn: classifier prediction probability function, which
+                takes a numpy array and outputs prediction probabilities.  For
+                ScikitClassifiers , this is classifier.predict_proba.
+            labels: iterable with labels to be interpreted.
+            hide_color: TODO
+            num_features: maximum number of features present in interpretation
+            num_samples: size of the neighborhood to learn the linear model
+            batch_size: TODO
+            distance_metric: the distance metric to use for weights.
+            model_regressor: sklearn regressor to use in interpretation. Defaults
+            to Ridge regression in LimeBase. Must have model_regressor.coef_
+            and 'sample_weight' as a parameter to model_regressor.fit()
+
+        Returns:
+            An ImageIinterpretation object (see lime_image.py) with the corresponding
+            interpretations.
+        """
+        if len(image.shape) == 2:
+            image = gray2rgb(image)
+
+        try:
+            segments = quickshift(image, sigma=1)
+        except ValueError as e:
+            raise e
+
+        self.segments = segments
+
+        fudged_image = image.copy()
+        if hide_color is None:
+            # if no hide_color, use the mean
+            for x in np.unique(segments):
+                mx = np.mean(image[segments == x], axis=0)
+                fudged_image[segments == x] = mx
+        elif hide_color == 'avg_from_neighbor':
+            from scipy.spatial.distance import cdist
+
+            n_features = np.unique(segments).shape[0]
+            regions = regionprops(segments + 1)
+            centroids = np.zeros((n_features, 2))
+            for i, x in enumerate(regions):
+                centroids[i] = np.array(x.centroid)
+
+            d = cdist(centroids, centroids, 'sqeuclidean')
+
+            for x in np.unique(segments):
+                # print(np.argmin(d[x]))
+                a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
+                mx = np.mean(np.concatenate(a), axis=0)
+                fudged_image[segments == x] = mx
+
+        else:
+            fudged_image[:] = 0
+
+        top = labels
+
+        data, labels = self.data_labels(image, fudged_image, segments,
+                                        classifier_fn, num_samples,
+                                        batch_size=batch_size)
+
+        distances = sklearn.metrics.pairwise_distances(
+            data,
+            data[0].reshape(1, -1),
+            metric=distance_metric
+        ).ravel()
+
+        interpretation_image = ImageInterpretation(image, segments)
+        for label in top:
+            (interpretation_image.intercept[label],
+             interpretation_image.local_weights[label],
+             interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
+                data, labels, distances, label, num_features,
+                model_regressor=model_regressor,
+                feature_selection=self.feature_selection)
+        return interpretation_image
+
+    def data_labels(self,
+                    image,
+                    fudged_image,
+                    segments,
+                    classifier_fn,
+                    num_samples,
+                    batch_size=10):
+        """Generates images and predictions in the neighborhood of this image.
+
+        Args:
+            image: 3d numpy array, the image
+            fudged_image: 3d numpy array, image to replace original image when
+                superpixel is turned off
+            segments: segmentation of the image
+            classifier_fn: function that takes a list of images and returns a
+                matrix of prediction probabilities
+            num_samples: size of the neighborhood to learn the linear model
+            batch_size: classifier_fn will be called on batches of this size.
+
+        Returns:
+            A tuple (data, labels), where:
+                data: dense num_samples * num_superpixels
+                labels: prediction probabilities matrix
+        """
+        n_features = np.unique(segments).shape[0]
+        data = self.random_state.randint(0, 2, num_samples * n_features) \
+            .reshape((num_samples, n_features))
+        labels = []
+        data[0, :] = 1
+        imgs = []
+        for row in data:
+            temp = copy.deepcopy(image)
+            zeros = np.where(row == 0)[0]
+            mask = np.zeros(segments.shape).astype(bool)
+            for z in zeros:
+                mask[segments == z] = True
+            temp[mask] = fudged_image[mask]
+            imgs.append(temp)
+            if len(imgs) == batch_size:
+                preds = classifier_fn(np.array(imgs))
+                labels.extend(preds)
+                imgs = []
+        if len(imgs) > 0:
+            preds = classifier_fn(np.array(imgs))
+            labels.extend(preds)
+        return data, np.array(labels)

+ 221 - 0
paddlex/interpret/core/normlime_base.py

@@ -0,0 +1,221 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import numpy as np
+import glob
+
+from ..as_data_reader.readers import read_image
+from . import lime_base
+from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans
+
+
+def load_kmeans_model(fname):
+    import pickle
+    with open(fname, 'rb') as f:
+        kmeans_model = pickle.load(f)
+
+    return kmeans_model
+
+
+def combine_normlime_and_lime(lime_weights, g_weights):
+    pred_labels = lime_weights.keys()
+    combined_weights = {y: [] for y in pred_labels}
+
+    for y in pred_labels:
+        normlized_lime_weights_y = lime_weights[y]
+        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
+
+        normlized_g_weight_y = g_weights[y]
+        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
+
+        combined_weights[y] = [
+            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
+            for seg_k in lime_weights_dict.keys()
+        ]
+
+        combined_weights[y] = sorted(combined_weights[y],
+                                     key=lambda x: np.abs(x[1]), reverse=True)
+
+    return combined_weights
+
+
+def avg_using_superpixels(features, segments):
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for x in np.unique(segments):
+        one_list[x] = np.mean(features[segments == x], axis=0)
+
+    return one_list
+
+
+def centroid_using_superpixels(features, segments):
+    from skimage.measure import regionprops
+    regions = regionprops(segments + 1)
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for i, r in enumerate(regions):
+        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
+    # print(one_list.shape)
+    return one_list
+
+
+def get_feature_for_kmeans(feature_map, segments):
+    from sklearn.preprocessing import normalize
+    centroid_feature = centroid_using_superpixels(feature_map, segments)
+    avg_feature = avg_using_superpixels(feature_map, segments)
+    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
+    x = normalize(x)
+    return x
+
+
+def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_size=50, save_dir='./tmp'):
+    # save lime weights and kmeans cluster labels
+    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
+
+    # load precomputed results, compute normlime weights and save.
+    fname_list = glob.glob(os.path.join(save_dir, f'lime_weights_s{num_samples}*.npy'))
+    return compute_normlime_weights(fname_list, save_dir, num_samples)
+
+
+def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):
+
+    lime_weights = {}
+    for label in image_pred_labels:
+        lime_weights[label] = lime_all_weights[label]
+
+    for_normlime_weights = {
+        'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
+        'cluster': cluster_labels  # a list with segments as indices.
+    }
+
+    np.save(save_path, for_normlime_weights)
+
+
+def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
+    kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+
+    for data_index, each_data_ in enumerate(list_data_):
+        if isinstance(each_data_, str):
+            save_path = f"lime_weights_s{num_samples}_{each_data_.split('/')[-1].split('.')[0]}.npy"
+            save_path = os.path.join(save_dir, save_path)
+        else:
+            save_path = f"lime_weights_s{num_samples}_{data_index}.npy"
+            save_path = os.path.join(save_dir, save_path)
+
+        if os.path.exists(save_path):
+            print(f'{save_path} exists, not computing this one.')
+            continue
+
+        print('processing', each_data_ if isinstance(each_data_, str) else data_index,
+              f', {data_index}/{len(list_data_)}')
+
+        image_show = read_image(each_data_)
+        result = predict_fn(image_show)
+        result = result[0]  # only one image here.
+
+        if abs(np.sum(result) - 1.0) > 1e-4:
+            # softmax
+            exp_result = np.exp(result)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = result
+
+        pred_label = np.argsort(probability)[::-1]
+
+        # top_k = argmin(top_n) > threshold
+        threshold = 0.05
+        top_k = 0
+        for l in pred_label:
+            if probability[l] < threshold or top_k == 5:
+                break
+            top_k += 1
+
+        if top_k == 0:
+            top_k = 1
+
+        pred_label = pred_label[:top_k]
+
+        algo = lime_base.LimeImageInterpreter()
+        interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
+                                          num_samples=num_samples, batch_size=batch_size)
+
+        cluster_labels = kmeans_model.predict(
+            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
+        )
+        save_one_lime_predict_and_kmean_labels(
+            interpreter.local_weights, pred_label,
+            cluster_labels,
+            save_path
+        )
+
+
+def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
+    normlime_weights_all_labels = {}
+    for f in a_list_lime_fnames:
+        try:
+            lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
+            lime_weights = lime_weights_and_cluster['lime_weights']
+            cluster = lime_weights_and_cluster['cluster']
+        except:
+            print('When loading precomputed LIME result, skipping', f)
+            continue
+        print('Loading precomputed LIME result,', f)
+
+        pred_labels = lime_weights.keys()
+        for y in pred_labels:
+            normlime_weights = normlime_weights_all_labels.get(y, {})
+            w_f_y = [abs(w[1]) for w in lime_weights[y]]
+            w_f_y_l1norm = sum(w_f_y)
+
+            for w in lime_weights[y]:
+                seg_label = w[0]
+                weight = w[1] * w[1] / w_f_y_l1norm
+                a = normlime_weights.get(cluster[seg_label], [])
+                a.append(weight)
+                normlime_weights[cluster[seg_label]] = a
+
+            normlime_weights_all_labels[y] = normlime_weights
+
+    # compute normlime
+    for y in normlime_weights_all_labels:
+        normlime_weights = normlime_weights_all_labels.get(y, {})
+        for k in normlime_weights:
+            normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k])
+
+    # check normlime
+    if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1:
+        print(
+            "\n"
+            "Warning: !!! \n"
+            f"There are at least {max(normlime_weights_all_labels.keys()) + 1} classes, "
+            f"but the NormLIME has results of only {len(normlime_weights_all_labels.keys())} classes. \n"
+            "It may have cause unstable results in the later computation"
+            " but can be improved by computing more test samples."
+            "\n"
+        )
+
+    n = 0
+    f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+    while os.path.exists(
+            os.path.join(save_dir, f_out)
+    ):
+        n += 1
+        f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+        continue
+
+    np.save(
+        os.path.join(save_dir, f_out),
+        normlime_weights_all_labels
+    )
+    return os.path.join(save_dir, f_out)
+

+ 29 - 0
paddlex/interpret/interpretation_predict.py

@@ -0,0 +1,29 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+def interpretation_predict(model, images):
+    model.arrange_transforms(
+            transforms=model.test_transforms, mode='test')
+    new_imgs = []
+    for i in range(images.shape[0]):
+        img = images[i]
+        new_imgs.append(model.test_transforms(img)[0])
+    new_imgs = np.array(new_imgs)
+    result = model.exe.run(
+        model.test_prog,
+        feed={'image': new_imgs},
+        fetch_list=list(model.interpretation_feats.values()))
+    return result

+ 141 - 0
paddlex/interpret/visualize.py

@@ -0,0 +1,141 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import cv2
+import copy
+import os.path as osp
+import numpy as np
+import paddlex as pdx
+from .interpretation_predict import interpretation_predict
+from .core.interpretation import Interpretation
+from .core.normlime_base import precompute_normlime_weights
+
+
+def visualize(img_file, 
+              model, 
+              dataset=None,
+              algo='lime',
+              num_samples=3000, 
+              batch_size=50,
+              save_dir='./'):
+    """可解释性可视化。
+    Args:
+        img_file (str): 预测图像路径。
+        model (paddlex.cv.models): paddlex中的模型。
+        dataset (paddlex.datasets): 数据集读取器,默认为None。
+        algo (str): 可解释性方式,当前可选'lime'和'normlime'。
+        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
+        batch_size (int): 预测数据batch大小,默认为50。
+        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
+    """
+    assert model.model_type == 'classifier', \
+        'Now the interpretation visualize only be supported in classifier!'
+    if model.status != 'Normal':
+        raise Exception('The interpretation only can deal with the Normal model')
+    model.arrange_transforms(
+                transforms=model.test_transforms, mode='test')
+    tmp_transforms = copy.deepcopy(model.test_transforms)
+    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
+    img = tmp_transforms(img_file)[0]
+    img = np.around(img).astype('uint8')
+    img = np.expand_dims(img, axis=0)
+    interpreter = None
+    if algo == 'lime':
+        interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
+    elif algo == 'normlime':
+        if dataset is None:
+            raise Exception('The dataset is None. Cannot implement this kind of interpretation')
+        interpreter = get_normlime_interpreter(img, model, dataset, 
+                                     num_samples=num_samples, batch_size=batch_size,
+                                     save_dir=save_dir)
+    else:
+        raise Exception('The {} interpretation method is not supported yet!'.format(algo))
+    img_name = osp.splitext(osp.split(img_file)[-1])[0]
+    interpreter.interpret(img, save_dir=save_dir)
+    
+    
+def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
+    def predict_func(image):
+        image = image.astype('float32')
+        for i in range(image.shape[0]):
+            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = interpretation_predict(model, image)
+        model.test_transforms.transforms = tmp_transforms
+        return out[0]
+    labels_name = None
+    if dataset is not None:
+        labels_name = dataset.labels
+    interpreter = Interpretation('lime', 
+                            predict_func,
+                            labels_name,
+                            num_samples=num_samples, 
+                            batch_size=batch_size)
+    return interpreter
+
+
+def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
+    def precompute_predict_func(image):
+        image = image.astype('float32')
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = interpretation_predict(model, image)
+        model.test_transforms.transforms = tmp_transforms
+        return out[0]
+    def predict_func(image):
+        image = image.astype('float32')
+        for i in range(image.shape[0]):
+            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = interpretation_predict(model, image)
+        model.test_transforms.transforms = tmp_transforms
+        return out[0]
+    labels_name = None
+    if dataset is not None:
+        labels_name = dataset.labels
+    root_path = os.environ['HOME']
+    root_path = osp.join(root_path, '.paddlex')
+    pre_models_path = osp.join(root_path, "pre_models")
+    if not osp.exists(pre_models_path):
+        os.makedirs(pre_models_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=pre_models_path)
+    npy_dir = precompute_for_normlime(precompute_predict_func, 
+                                      dataset, 
+                                      num_samples=num_samples, 
+                                      batch_size=batch_size,
+                                      save_dir=save_dir)
+    interpreter = Interpretation('normlime', 
+                            predict_func,
+                            labels_name,
+                            num_samples=num_samples, 
+                            batch_size=batch_size,
+                            normlime_weights=npy_dir)
+    return interpreter
+
+
+def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
+    image_list = []
+    for item in dataset.file_list:
+        image_list.append(item[0])
+    return precompute_normlime_weights(
+            image_list,  
+            predict_func,
+            num_samples=num_samples, 
+            batch_size=batch_size,
+            save_dir=save_dir)
+  

+ 1 - 8
paddlex/tools/__init__.py

@@ -14,11 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .x2imagenet import EasyData2ImageNet
-from .x2coco import LabelMe2COCO
-from .x2coco import EasyData2COCO
-from .x2voc import LabelMe2VOC
-from .x2voc import EasyData2VOC
-from .x2seg import JingLing2Seg
-from .x2seg import LabelMe2Seg
-from .x2seg import EasyData2Seg
+from .convert import *

+ 33 - 0
paddlex/tools/convert.py

@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .x2imagenet import EasyData2ImageNet
+from .x2coco import LabelMe2COCO
+from .x2coco import EasyData2COCO
+from .x2voc import LabelMe2VOC
+from .x2voc import EasyData2VOC
+from .x2seg import JingLing2Seg
+from .x2seg import LabelMe2Seg
+from .x2seg import EasyData2Seg
+
+easydata2imagenet = EasyData2ImageNet().convert
+labelme2coco = LabelMe2COCO().convert
+easydata2coco = EasyData2COCO().convert
+labelme2voc = LabelMe2VOC().convert
+easydata2voc = EasyData2VOC().convert
+jingling2seg = JingLing2Seg().convert
+labelme2seg = LabelMe2Seg().convert
+easydata2seg = EasyData2Seg().convert

+ 47 - 0
tutorials/interpret/interpret.py

@@ -0,0 +1,47 @@
+import os
+# 选择使用0号卡
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import os.path as osp
+import paddlex as pdx
+from paddlex.cls import transforms
+
+# 下载和解压Imagenet果蔬分类数据集
+veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# 定义测试集的transform
+test_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224),
+    transforms.Normalize()
+])
+
+# 定义测试所用的数据集
+test_dataset = pdx.datasets.ImageNet(
+    data_dir='mini_imagenet_veg',
+    file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
+    label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
+    transforms=test_transforms)
+
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
+
+# 导入模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
+# 可解释性可视化
+save_dir = 'interpret_results'
+if not osp.exists(save_dir):
+    os.makedirs(save_dir)
+pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model,
+          test_dataset, 
+          algo='lime',
+          save_dir=save_dir)
+pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model, 
+          test_dataset, 
+          algo='normlime',
+          save_dir=save_dir)