Przeglądaj źródła

Merge pull request #11 from PaddlePaddle/develop

alex
SunAhong1993 5 lat temu
rodzic
commit
f13d936746

+ 1 - 1
deploy/cpp/src/paddlex.cpp

@@ -98,7 +98,7 @@ bool Model::load_config(const std::string& model_dir) {
 
 bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) {
   cv::Mat im = input_im.clone();
-  if (!transforms_.Run(&im, &inputs_)) {
+  if (!transforms_.Run(&im, blob)) {
     return false;
   }
   return true;

+ 6 - 0
docs/FAQ.md

@@ -60,3 +60,9 @@
 ## 11. 每次训练新的模型,都需要重新下载预训练模型,怎样可以下载一次就搞定
 > 1.可以按照9的方式来解决这个问题  
 > 2.每次训练前都设定`paddlex.pretrain_dir`路径,如设定`paddlex.pretrain_dir='/usrname/paddlex`,如此下载完的预训练模型会存放至`/usrname/paddlex`目录下,而已经下载在该目录的模型也不会再次重复下载
+
+## 12. 程序启动时提示"Failed to execute script PaddleX",如何解决?
+> 1. 请检查目标机器上PaddleX程序所在路径是否包含中文。目前暂不支持中文路径,请尝试将程序移动到英文目录。
+> 2. 如果您的系统是Windows 7或者Windows Server 2012时,原因是缺少MFPlat.DLL/MF.dll/MFReadWrite.dll等OpenCV依赖的DLL,请按如下方式安装桌面体验:通过“我的电脑”-->“属性”-->"管理"打开服务器管理器,点击右上角“管理”选择“添加角色和功能”。点击“服务器选择”-->“功能”,拖动滚动条到最下端,点开“用户界面和基础结构”,勾选“桌面体验”后点击“安装”,等安装完成尝试再次运行PaddleX。
+> 3. 请检查目标机器上是否有其他的PaddleX程序或者进程在运行中,如有请退出或者重启机器看是否解决
+> 4. 请确认运行程序的用户是否有管理员权限,如非管理员权限用户请尝试使用管理员运行看是否成功

+ 1 - 1
paddlex/__init__.py

@@ -53,4 +53,4 @@ log_level = 2
 
 from . import interpret
 
-__version__ = '1.0.5'
+__version__ = '1.0.6'

+ 1 - 0
paddlex/cls.py

@@ -37,5 +37,6 @@ DenseNet161 = cv.models.DenseNet161
 DenseNet201 = cv.models.DenseNet201
 ShuffleNetV2 = cv.models.ShuffleNetV2
 HRNet_W18 = cv.models.HRNet_W18
+AlexNet = cv.models.AlexNet
 
 transforms = cv.transforms.cls_transforms

+ 1 - 0
paddlex/cv/models/__init__.py

@@ -35,6 +35,7 @@ from .classifier import DenseNet161
 from .classifier import DenseNet201
 from .classifier import ShuffleNetV2
 from .classifier import HRNet_W18
+from .classifier import AlexNet
 from .base import BaseAPI
 from .yolo_v3 import YOLOv3
 from .faster_rcnn import FasterRCNN

+ 9 - 13
paddlex/cv/models/base.py

@@ -221,8 +221,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(
-                self.exe, self.train_prog, pretrain_weights, fuse_bn)
+            paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
+                                                      pretrain_weights, fuse_bn)
         # 进行裁剪
         if sensitivities_file is not None:
             import paddleslim
@@ -262,6 +262,7 @@ class BaseAPI:
 
         info['_Attributes']['num_classes'] = self.num_classes
         info['_Attributes']['labels'] = self.labels
+        info['_Attributes']['fixed_input_shape'] = self.fixed_input_shape
         try:
             primary_metric_key = list(self.eval_metrics.keys())[0]
             primary_metric_value = float(self.eval_metrics[primary_metric_key])
@@ -325,9 +326,7 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))
 
     def export_inference_model(self, save_dir):
-        test_input_names = [
-            var.name for var in list(self.test_inputs.values())
-        ]
+        test_input_names = [var.name for var in list(self.test_inputs.values())]
         test_outputs = list(self.test_outputs.values())
         if self.__class__.__name__ == 'MaskRCNN':
             from paddlex.utils.save import save_mask_inference_model
@@ -364,8 +363,7 @@ class BaseAPI:
 
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(
-            save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(save_dir))
 
     def train_loop(self,
                    num_epochs,
@@ -489,13 +487,11 @@ class BaseAPI:
                         eta = ((num_epochs - i) * total_num_steps - step - 1
                                ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (
-                            total_eval_times - i // save_interval_epochs
-                        ) * time_eval_one_epoch
+                        eval_eta = (total_eval_times - i // save_interval_epochs
+                                    ) * time_eval_one_epoch
                     else:
-                        eval_eta = (
-                            total_eval_times - i // save_interval_epochs
-                        ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (total_eval_times - i // save_interval_epochs
+                                    ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
 
                     logging.info(

+ 9 - 0
paddlex/cv/models/classifier.py

@@ -48,6 +48,8 @@ class BaseClassifier(BaseAPI):
         self.fixed_input_shape = None
 
     def build_net(self, mode='train'):
+        if self.__class__.__name__ == "AlexNet":
+            assert self.fixed_input_shape is not None, "In AlexNet, input_shape should be defined, e.g. model = paddlex.cls.AlexNet(num_classes=1000, input_shape=[224, 224])"
         if self.fixed_input_shape is not None:
             input_shape = [
                 None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
@@ -427,3 +429,10 @@ class HRNet_W18(BaseClassifier):
     def __init__(self, num_classes=1000):
         super(HRNet_W18, self).__init__(
             model_name='HRNet_W18', num_classes=num_classes)
+
+
+class AlexNet(BaseClassifier):
+    def __init__(self, num_classes=1000, input_shape=None):
+        super(AlexNet, self).__init__(
+            model_name='AlexNet', num_classes=num_classes)
+        self.fixed_input_shape = input_shape

+ 11 - 2
paddlex/cv/models/load_model.py

@@ -41,7 +41,16 @@ def load_model(model_dir, fixed_input_shape=None):
     if 'model_name' in info['_init_params']:
         del info['_init_params']['model_name']
     model = getattr(paddlex.cv.models, info['Model'])(**info['_init_params'])
+
     model.fixed_input_shape = fixed_input_shape
+    if '_Attributes' in info:
+        if 'fixed_input_shape' in info['_Attributes']:
+            fixed_input_shape = info['_Attributes']['fixed_input_shape']
+            if fixed_input_shape is not None:
+                logging.info("Model already has fixed_input_shape with {}".
+                             format(fixed_input_shape))
+                model.fixed_input_shape = fixed_input_shape
+
     if status == "Normal" or \
             status == "Prune" or status == "fluid.save":
         startup_prog = fluid.Program()
@@ -88,8 +97,8 @@ def load_model(model_dir, fixed_input_shape=None):
                 model.model_type, info['Transforms'], info['BatchTransforms'])
             model.eval_transforms = copy.deepcopy(model.test_transforms)
         else:
-            model.test_transforms = build_transforms(
-                model.model_type, info['Transforms'], to_rgb)
+            model.test_transforms = build_transforms(model.model_type,
+                                                     info['Transforms'], to_rgb)
             model.eval_transforms = copy.deepcopy(model.test_transforms)
 
     if '_Attributes' in info:

+ 8 - 4
paddlex/cv/models/utils/pretrain_weights.py

@@ -70,6 +70,8 @@ image_pretrain = {
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
     'HRNet_W64':
     'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
+    'AlexNet':
+    'http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar'
 }
 
 coco_pretrain = {
@@ -99,10 +101,12 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
                 backbone = 'DetResNet50'
         assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
             backbone)
-        #        url = image_pretrain[backbone]
-        #        fname = osp.split(url)[-1].split('.')[0]
-        #        paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        #        return osp.join(new_save_dir, fname)
+
+        #        if backbone == 'AlexNet':
+        #            url = image_pretrain[backbone]
+        #            fname = osp.split(url)[-1].split('.')[0]
+        #            paddlex.utils.download_and_decompress(url, path=new_save_dir)
+        #            return osp.join(new_save_dir, fname)
         try:
             hub.download(backbone, save_path=new_save_dir)
         except Exception as e:

+ 6 - 0
paddlex/cv/nets/__init__.py

@@ -24,6 +24,7 @@ from .xception import Xception
 from .densenet import DenseNet
 from .shufflenet_v2 import ShuffleNetV2
 from .hrnet import HRNet
+from .alexnet import AlexNet
 
 
 def resnet18(input, num_classes=1000):
@@ -153,3 +154,8 @@ def shufflenetv2(input, num_classes=1000):
 def hrnet_w18(input, num_classes=1000):
     model = HRNet(width=18, num_classes=num_classes)
     return model(input)
+
+
+def alexnet(input, num_classes=1000):
+    model = AlexNet(num_classes=num_classes)
+    return model(input)

+ 170 - 0
paddlex/cv/nets/alexnet.py

@@ -0,0 +1,170 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import paddle
+import paddle.fluid as fluid
+
+
+class AlexNet():
+    def __init__(self, num_classes=1000):
+        assert num_classes is not None, "In AlextNet, num_classes cannot be None"
+        self.num_classes = num_classes
+
+    def __call__(self, input):
+        stdv = 1.0 / math.sqrt(input.shape[1] * 11 * 11)
+        layer_name = [
+            "conv1", "conv2", "conv3", "conv4", "conv5", "fc6", "fc7", "fc8"
+        ]
+        conv1 = fluid.layers.conv2d(
+            input=input,
+            num_filters=64,
+            filter_size=11,
+            stride=4,
+            padding=2,
+            groups=1,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[0] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[0] + "_weights"))
+        pool1 = fluid.layers.pool2d(
+            input=conv1,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=0,
+            pool_type='max')
+
+        stdv = 1.0 / math.sqrt(pool1.shape[1] * 5 * 5)
+        conv2 = fluid.layers.conv2d(
+            input=pool1,
+            num_filters=192,
+            filter_size=5,
+            stride=1,
+            padding=2,
+            groups=1,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[1] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[1] + "_weights"))
+        pool2 = fluid.layers.pool2d(
+            input=conv2,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=0,
+            pool_type='max')
+
+        stdv = 1.0 / math.sqrt(pool2.shape[1] * 3 * 3)
+        conv3 = fluid.layers.conv2d(
+            input=pool2,
+            num_filters=384,
+            filter_size=3,
+            stride=1,
+            padding=1,
+            groups=1,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[2] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[2] + "_weights"))
+
+        stdv = 1.0 / math.sqrt(conv3.shape[1] * 3 * 3)
+        conv4 = fluid.layers.conv2d(
+            input=conv3,
+            num_filters=256,
+            filter_size=3,
+            stride=1,
+            padding=1,
+            groups=1,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[3] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[3] + "_weights"))
+
+        stdv = 1.0 / math.sqrt(conv4.shape[1] * 3 * 3)
+        conv5 = fluid.layers.conv2d(
+            input=conv4,
+            num_filters=256,
+            filter_size=3,
+            stride=1,
+            padding=1,
+            groups=1,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[4] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[4] + "_weights"))
+        pool5 = fluid.layers.pool2d(
+            input=conv5,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=0,
+            pool_type='max')
+
+        drop6 = fluid.layers.dropout(x=pool5, dropout_prob=0.5)
+        stdv = 1.0 / math.sqrt(drop6.shape[1] * drop6.shape[2] *
+                               drop6.shape[3] * 1.0)
+
+        fc6 = fluid.layers.fc(
+            input=drop6,
+            size=4096,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[5] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[5] + "_weights"))
+        drop7 = fluid.layers.dropout(x=fc6, dropout_prob=0.5)
+        stdv = 1.0 / math.sqrt(drop7.shape[1] * 1.0)
+
+        fc7 = fluid.layers.fc(
+            input=drop7,
+            size=4096,
+            act='relu',
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[6] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[6] + "_weights"))
+
+        stdv = 1.0 / math.sqrt(fc7.shape[1] * 1.0)
+        out = fluid.layers.fc(
+            input=fc7,
+            size=self.num_classes,
+            bias_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[7] + "_offset"),
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=layer_name[7] + "_weights"))
+        return out

+ 2 - 2
paddlex/cv/nets/hrnet.py

@@ -71,7 +71,7 @@ class HRNet(object):
         self.end_points = []
         return
 
-    def net(self, input, class_dim=1000):
+    def net(self, input):
         width = self.width
         channels_2, channels_3, channels_4 = self.channels[width]
         num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
@@ -125,7 +125,7 @@ class HRNet(object):
             stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
             out = fluid.layers.fc(
                 input=pool,
-                size=class_dim,
+                size=self.num_classes,
                 param_attr=ParamAttr(
                     name='fc_weights',
                     initializer=fluid.initializer.Uniform(-stdv, stdv)),

+ 24 - 8
paddlex/interpret/core/_session_preparation.py

@@ -20,6 +20,7 @@ import numpy as np
 from paddle.fluid.param_attr import ParamAttr
 from paddlex.interpret.as_data_reader.readers import preprocess_image
 
+
 def gen_user_home():
     if "HOME" in os.environ:
         home_path = os.environ["HOME"]
@@ -34,10 +35,20 @@ def paddle_get_fc_weights(var_name="fc_0.w_0"):
 
 
 def paddle_resize(extracted_features, outsize):
-    resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
+    resized_features = fluid.layers.resize_bilinear(extracted_features,
+                                                    outsize)
     return resized_features
 
 
+def get_precomputed_normlime_weights():
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    normlime_weights_file = osp.join(
+        h_pre_models, "normlime_weights_imagenet_resnet50vc.npy")
+    return np.load(normlime_weights_file, allow_pickle=True).item()
+
+
 def compute_features_for_kmeans(data_content):
     root_path = gen_user_home()
     root_path = osp.join(root_path, '.paddlex')
@@ -47,6 +58,7 @@ def compute_features_for_kmeans(data_content):
             os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         pdx.utils.download_and_decompress(url, path=root_path)
+
     def conv_bn_layer(input,
                       num_filters,
                       filter_size,
@@ -55,7 +67,7 @@ def compute_features_for_kmeans(data_content):
                       act=None,
                       name=None,
                       is_test=True,
-                      global_name=''):
+                      global_name='for_kmeans_'):
         conv = fluid.layers.conv2d(
             input=input,
             num_filters=num_filters,
@@ -79,14 +91,14 @@ def compute_features_for_kmeans(data_content):
             bias_attr=ParamAttr(global_name + bn_name + '_offset'),
             moving_mean_name=global_name + bn_name + '_mean',
             moving_variance_name=global_name + bn_name + '_variance',
-            use_global_stats=is_test
-        )
+            use_global_stats=is_test)
 
     startup_prog = fluid.default_startup_program().clone(for_test=True)
     prog = fluid.Program()
     with fluid.program_guard(prog, startup_prog):
         with fluid.unique_name.guard():
-            image_op = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
+            image_op = fluid.data(
+                name='image', shape=[None, 3, 224, 224], dtype='float32')
 
             conv = conv_bn_layer(
                 input=image_op,
@@ -110,7 +122,8 @@ def compute_features_for_kmeans(data_content):
                 act='relu',
                 name='conv1_3')
             extracted_features = conv
-            resized_features = fluid.layers.resize_bilinear(extracted_features, image_op.shape[2:])
+            resized_features = fluid.layers.resize_bilinear(extracted_features,
+                                                            image_op.shape[2:])
 
     gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
     place = fluid.CUDAPlace(gpu_id)
@@ -119,7 +132,10 @@ def compute_features_for_kmeans(data_content):
     exe.run(startup_prog)
     fluid.io.load_persistables(exe, h_pre_models, prog)
 
-    images = preprocess_image(data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
-    result = exe.run(prog, fetch_list=[resized_features], feed={'image': images})
+    images = preprocess_image(
+        data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
+    result = exe.run(prog,
+                     fetch_list=[resized_features],
+                     feed={'image': images})
 
     return result[0][0]

+ 7 - 11
paddlex/interpret/core/interpretation.py

@@ -20,12 +20,10 @@ class Interpretation(object):
     """
     Base class for all interpretation algorithms.
     """
-    def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs):
-        supported_algorithms = {
-            'cam': CAM,
-            'lime': LIME,
-            'normlime': NormLIME
-        }
+
+    def __init__(self, interpretation_algorithm_name, predict_fn, label_names,
+                 **kwargs):
+        supported_algorithms = {'cam': CAM, 'lime': LIME, 'normlime': NormLIME}
 
         self.algorithm_name = interpretation_algorithm_name.lower()
         assert self.algorithm_name in supported_algorithms.keys()
@@ -33,19 +31,17 @@ class Interpretation(object):
 
         # initialization for the interpretation algorithm.
         self.algorithm = supported_algorithms[self.algorithm_name](
-            self.predict_fn, label_names, **kwargs
-        )
+            self.predict_fn, label_names, **kwargs)
 
-    def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
+    def interpret(self, data_, visualization=True, save_dir='./'):
         """
 
         Args:
             data_: data_ can be a path or numpy.ndarray.
             visualization: whether to show using matplotlib.
-            save_to_disk: whether to save the figure in local disk.
             save_dir: dir to save figure if save_to_disk is True.
 
         Returns:
 
         """
-        return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir)
+        return self.algorithm.interpret(data_, visualization, save_dir)

+ 316 - 106
paddlex/interpret/core/interpretation_algorithms.py

@@ -23,7 +23,6 @@ from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, lo
 from paddlex.interpret.as_data_reader.readers import read_image
 import paddlex.utils.logging as logging
 
-
 import cv2
 
 
@@ -66,25 +65,27 @@ class CAM(object):
 
         fc_weights = paddle_get_fc_weights()
         feature_maps = result[1]
-        
+
         l = pred_label[0]
         ln = l
         if self.label_names is not None:
             ln = self.label_names[l]
 
         prob_str = "%.3f" % (probability[pred_label[0]])
-        logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
+        logging.info("predicted result: {} with probability {}.".format(
+            ln, prob_str))
         return feature_maps, fc_weights
 
-    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         feature_maps, fc_weights = self.preparation_cam(data_)
-        cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
+        cam = get_cam(self.image, feature_maps, fc_weights,
+                      self.predicted_label)
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
-            ln = l 
+            ln = l
             if self.label_names is not None:
                 ln = self.label_names[l]
 
@@ -93,7 +94,8 @@ class CAM(object):
             ncols = 2
 
             plt.close()
-            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            f, axes = plt.subplots(
+                nrows, ncols, figsize=(psize * ncols, psize * nrows))
             for ax in axes.ravel():
                 ax.axis("off")
             axes = axes.ravel()
@@ -104,8 +106,7 @@ class CAM(object):
             axes[1].imshow(cam)
             axes[1].set_title("CAM")
 
-        if save_to_disk and save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
+        if save_outdir is not None:
             save_fig(data_, save_outdir, 'cam')
 
         if visualization:
@@ -115,7 +116,11 @@ class CAM(object):
 
 
 class LIME(object):
-    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
+    def __init__(self,
+                 predict_fn,
+                 label_names,
+                 num_samples=3000,
+                 batch_size=50):
         """
         LIME wrapper. See lime_base.py for the detailed LIME implementation.
         Args:
@@ -154,31 +159,37 @@ class LIME(object):
         self.predicted_probability = probability[pred_label[0]]
         self.image = image_show[0]
         self.labels = pred_label
-        
+
         l = pred_label[0]
         ln = l
         if self.label_names is not None:
             ln = self.label_names[l]
-            
+
         prob_str = "%.3f" % (probability[pred_label[0]])
-        logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
+        logging.info("predicted result: {} with probability {}.".format(
+            ln, prob_str))
 
         end = time.time()
         algo = lime_base.LimeImageInterpreter()
-        interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
-                                              num_samples=self.num_samples, batch_size=self.batch_size)
+        interpreter = algo.interpret_instance(
+            self.image,
+            self.predict_fn,
+            self.labels,
+            0,
+            num_samples=self.num_samples,
+            batch_size=self.batch_size)
         self.lime_interpreter = interpreter
         logging.info('lime time: ' + str(time.time() - end) + 's.')
 
-    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         if self.lime_interpreter is None:
             self.preparation_lime(data_)
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
-            ln = l 
+            ln = l
             if self.label_names is not None:
                 ln = self.label_names[l]
 
@@ -188,7 +199,8 @@ class LIME(object):
             ncols = len(weights_choices)
 
             plt.close()
-            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            f, axes = plt.subplots(
+                nrows, ncols, figsize=(psize * ncols, psize * nrows))
             for ax in axes.ravel():
                 ax.axis("off")
             axes = axes.ravel()
@@ -196,20 +208,24 @@ class LIME(object):
             prob_str = "{%.3f}" % (self.predicted_probability)
             axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
 
-            axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
+            axes[1].imshow(
+                mark_boundaries(self.image, self.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
 
             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
+                num_to_show = auto_choose_num_features_to_show(
+                    self.lime_interpreter, l, w)
                 temp, mask = self.lime_interpreter.get_image_and_mask(
-                    l, positive_only=False, hide_rest=False, num_features=num_to_show
-                )
+                    l,
+                    positive_only=True,
+                    hide_rest=False,
+                    num_features=num_to_show)
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title("label {}, first {} superpixels".format(ln, num_to_show))
+                axes[ncols + i].set_title(
+                    "label {}, first {} superpixels".format(ln, num_to_show))
 
-        if save_to_disk and save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
+        if save_outdir is not None:
             save_fig(data_, save_outdir, 'lime', self.num_samples)
 
         if visualization:
@@ -218,9 +234,196 @@ class LIME(object):
         return
 
 
+class NormLIMEStandard(object):
+    def __init__(self,
+                 predict_fn,
+                 label_names,
+                 num_samples=3000,
+                 batch_size=50,
+                 kmeans_model_for_normlime=None,
+                 normlime_weights=None):
+        root_path = gen_user_home()
+        root_path = osp.join(root_path, '.paddlex')
+        h_pre_models = osp.join(root_path, "pre_models")
+        if not osp.exists(h_pre_models):
+            if not osp.exists(root_path):
+                os.makedirs(root_path)
+            url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+            pdx.utils.download_and_decompress(url, path=root_path)
+        h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
+        if kmeans_model_for_normlime is None:
+            try:
+                self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+            except:
+                raise ValueError(
+                    "NormLIME needs the KMeans model, where we provided a default one in "
+                    "pre_models/kmeans_model.pkl.")
+        else:
+            logging.debug("Warning: It is *strongly* suggested to use the \
+            default KMeans model in pre_models/kmeans_model.pkl. \
+            Use another one will change the final result.")
+            self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
+
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        try:
+            self.normlime_weights = np.load(
+                normlime_weights, allow_pickle=True).item()
+        except:
+            self.normlime_weights = None
+            logging.debug(
+                "Warning: not find the correct precomputed Normlime result.")
+
+        self.predict_fn = predict_fn
+
+        self.labels = None
+        self.image = None
+        self.label_names = label_names
+
+    def predict_cluster_labels(self, feature_map, segments):
+        X = get_feature_for_kmeans(feature_map, segments)
+        try:
+            cluster_labels = self.kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, self.kmeans_model.cluster_centers_)
+        return cluster_labels
+
+    def predict_using_normlime_weights(self, pred_labels,
+                                       predicted_cluster_labels):
+        # global weights
+        g_weights = {y: [] for y in pred_labels}
+        for y in pred_labels:
+            cluster_weights_y = self.normlime_weights.get(y, {})
+            g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
+                            for i, k in enumerate(predicted_cluster_labels)]
+
+            g_weights[y] = sorted(
+                g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
+
+        return g_weights
+
+    def preparation_normlime(self, data_):
+        self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
+                          self.batch_size)
+        self._lime.preparation_lime(data_)
+
+        image_show = read_image(data_)
+
+        self.predicted_label = self._lime.predicted_label
+        self.predicted_probability = self._lime.predicted_probability
+        self.image = image_show[0]
+        self.labels = self._lime.labels
+        logging.info('performing NormLIME operations ...')
+
+        cluster_labels = self.predict_cluster_labels(
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
+            self._lime.lime_interpreter.segments)
+
+        g_weights = self.predict_using_normlime_weights(self.labels,
+                                                        cluster_labels)
+
+        return g_weights
+
+    def interpret(self, data_, visualization=True, save_outdir=None):
+        if self.normlime_weights is None:
+            raise ValueError(
+                "Not find the correct precomputed NormLIME result. \n"
+                "\t Try to call compute_normlime_weights() first or load the correct path."
+            )
+
+        g_weights = self.preparation_normlime(data_)
+        lime_weights = self._lime.lime_interpreter.local_weights
+
+        if visualization or save_outdir is not None:
+            import matplotlib.pyplot as plt
+            from skimage.segmentation import mark_boundaries
+            l = self.labels[0]
+            ln = l
+            if self.label_names is not None:
+                ln = self.label_names[l]
+
+            psize = 5
+            nrows = 4
+            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
+            nums_to_show = []
+            ncols = len(weights_choices)
+
+            plt.close()
+            f, axes = plt.subplots(
+                nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            for ax in axes.ravel():
+                ax.axis("off")
+
+            axes = axes.ravel()
+            axes[0].imshow(self.image)
+            prob_str = "{%.3f}" % (self.predicted_probability)
+            axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
+
+            axes[1].imshow(
+                mark_boundaries(self.image,
+                                self._lime.lime_interpreter.segments))
+            axes[1].set_title("superpixel segmentation")
+
+            # LIME visualization
+            for i, w in enumerate(weights_choices):
+                num_to_show = auto_choose_num_features_to_show(
+                    self._lime.lime_interpreter, l, w)
+                nums_to_show.append(num_to_show)
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l,
+                    positive_only=False,
+                    hide_rest=False,
+                    num_features=num_to_show)
+                axes[ncols + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols + i].set_title("LIME: first {} superpixels".format(
+                    num_to_show))
+
+            # NormLIME visualization
+            self._lime.lime_interpreter.local_weights = g_weights
+            for i, num_to_show in enumerate(nums_to_show):
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l,
+                    positive_only=False,
+                    hide_rest=False,
+                    num_features=num_to_show)
+                axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 2 + i].set_title(
+                    "NormLIME: first {} superpixels".format(num_to_show))
+
+            # NormLIME*LIME visualization
+            combined_weights = combine_normlime_and_lime(lime_weights,
+                                                         g_weights)
+            self._lime.lime_interpreter.local_weights = combined_weights
+            for i, num_to_show in enumerate(nums_to_show):
+                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
+                    l,
+                    positive_only=False,
+                    hide_rest=False,
+                    num_features=num_to_show)
+                axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
+                axes[ncols * 3 + i].set_title(
+                    "Combined: first {} superpixels".format(num_to_show))
+
+            self._lime.lime_interpreter.local_weights = lime_weights
+
+        if save_outdir is not None:
+            save_fig(data_, save_outdir, 'normlime', self.num_samples)
+
+        if visualization:
+            plt.show()
+
+
 class NormLIME(object):
-    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
-                 kmeans_model_for_normlime=None, normlime_weights=None):
+    def __init__(self,
+                 predict_fn,
+                 label_names,
+                 num_samples=3000,
+                 batch_size=50,
+                 kmeans_model_for_normlime=None,
+                 normlime_weights=None):
         root_path = gen_user_home()
         root_path = osp.join(root_path, '.paddlex')
         h_pre_models = osp.join(root_path, "pre_models")
@@ -234,8 +437,9 @@ class NormLIME(object):
             try:
                 self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
             except:
-                raise ValueError("NormLIME needs the KMeans model, where we provided a default one in "
-                                 "pre_models/kmeans_model.pkl.")
+                raise ValueError(
+                    "NormLIME needs the KMeans model, where we provided a default one in "
+                    "pre_models/kmeans_model.pkl.")
         else:
             logging.debug("Warning: It is *strongly* suggested to use the \
             default KMeans model in pre_models/kmeans_model.pkl. \
@@ -246,10 +450,12 @@ class NormLIME(object):
         self.batch_size = batch_size
 
         try:
-            self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
+            self.normlime_weights = np.load(
+                normlime_weights, allow_pickle=True).item()
         except:
             self.normlime_weights = None
-            logging.debug("Warning: not find the correct precomputed Normlime result.")
+            logging.debug(
+                "Warning: not find the correct precomputed Normlime result.")
 
         self.predict_fn = predict_fn
 
@@ -263,30 +469,27 @@ class NormLIME(object):
             cluster_labels = self.kmeans_model.predict(X)
         except AttributeError:
             from sklearn.metrics import pairwise_distances_argmin_min
-            cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_)
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, self.kmeans_model.cluster_centers_)
         return cluster_labels
 
-    def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
+    def predict_using_normlime_weights(self, pred_labels,
+                                       predicted_cluster_labels):
         # global weights
         g_weights = {y: [] for y in pred_labels}
         for y in pred_labels:
             cluster_weights_y = self.normlime_weights.get(y, {})
-            g_weights[y] = [
-                (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
-            ]
+            g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
+                            for i, k in enumerate(predicted_cluster_labels)]
 
-            g_weights[y] = sorted(g_weights[y],
-                                  key=lambda x: np.abs(x[1]), reverse=True)
+            g_weights[y] = sorted(
+                g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
 
         return g_weights
 
     def preparation_normlime(self, data_):
-        self._lime = LIME(
-            self.predict_fn,
-            self.label_names,
-            self.num_samples,
-            self.batch_size
-        )
+        self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
+                          self.batch_size)
         self._lime.preparation_lime(data_)
 
         image_show = read_image(data_)
@@ -298,22 +501,25 @@ class NormLIME(object):
         logging.info('performing NormLIME operations ...')
 
         cluster_labels = self.predict_cluster_labels(
-            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
-        )
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
+            self._lime.lime_interpreter.segments)
 
-        g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
+        g_weights = self.predict_using_normlime_weights(self.labels,
+                                                        cluster_labels)
 
         return g_weights
 
-    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         if self.normlime_weights is None:
-            raise ValueError("Not find the correct precomputed NormLIME result. \n"
-                             "\t Try to call compute_normlime_weights() first or load the correct path.")
+            raise ValueError(
+                "Not find the correct precomputed NormLIME result. \n"
+                "\t Try to call compute_normlime_weights() first or load the correct path."
+            )
 
         g_weights = self.preparation_normlime(data_)
         lime_weights = self._lime.lime_interpreter.local_weights
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
@@ -328,7 +534,8 @@ class NormLIME(object):
             ncols = len(weights_choices)
 
             plt.close()
-            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
+            f, axes = plt.subplots(
+                nrows, ncols, figsize=(psize * ncols, psize * nrows))
             for ax in axes.ravel():
                 ax.axis("off")
 
@@ -337,64 +544,83 @@ class NormLIME(object):
             prob_str = "{%.3f}" % (self.predicted_probability)
             axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
 
-            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
+            axes[1].imshow(
+                mark_boundaries(self.image,
+                                self._lime.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
 
             # LIME visualization
             for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
+                num_to_show = auto_choose_num_features_to_show(
+                    self._lime.lime_interpreter, l, w)
                 nums_to_show.append(num_to_show)
                 temp, mask = self._lime.lime_interpreter.get_image_and_mask(
-                    l, positive_only=False, hide_rest=False, num_features=num_to_show
-                )
+                    l,
+                    positive_only=True,
+                    hide_rest=False,
+                    num_features=num_to_show)
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title("LIME: first {} superpixels".format(num_to_show))
+                axes[ncols + i].set_title("LIME: first {} superpixels".format(
+                    num_to_show))
 
             # NormLIME visualization
             self._lime.lime_interpreter.local_weights = g_weights
             for i, num_to_show in enumerate(nums_to_show):
                 temp, mask = self._lime.lime_interpreter.get_image_and_mask(
-                    l, positive_only=False, hide_rest=False, num_features=num_to_show
-                )
+                    l,
+                    positive_only=True,
+                    hide_rest=False,
+                    num_features=num_to_show)
                 axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 2 + i].set_title("NormLIME: first {} superpixels".format(num_to_show))
+                axes[ncols * 2 + i].set_title(
+                    "NormLIME: first {} superpixels".format(num_to_show))
 
             # NormLIME*LIME visualization
-            combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
+            combined_weights = combine_normlime_and_lime(lime_weights,
+                                                         g_weights)
+
             self._lime.lime_interpreter.local_weights = combined_weights
             for i, num_to_show in enumerate(nums_to_show):
                 temp, mask = self._lime.lime_interpreter.get_image_and_mask(
-                    l, positive_only=False, hide_rest=False, num_features=num_to_show
-                )
+                    l,
+                    positive_only=True,
+                    hide_rest=False,
+                    num_features=num_to_show)
                 axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 3 + i].set_title("Combined: first {} superpixels".format(num_to_show))
+                axes[ncols * 3 + i].set_title(
+                    "Combined: first {} superpixels".format(num_to_show))
 
             self._lime.lime_interpreter.local_weights = lime_weights
 
-        if save_to_disk and save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
+        if save_outdir is not None:
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
 
         if visualization:
             plt.show()
 
 
-def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
+def auto_choose_num_features_to_show(lime_interpreter, label,
+                                     percentage_to_show):
     segments = lime_interpreter.segments
     lime_weights = lime_interpreter.local_weights[label]
-    num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
+    num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[
+        1] // len(np.unique(segments)) // 8
 
     # l1 norm with filtered weights.
-    used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
+    used_weights = [(tuple_w[0], tuple_w[1])
+                    for i, tuple_w in enumerate(lime_weights)
+                    if tuple_w[1] > 0]
     norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
-    normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
+    normalized_weights = [(tuple_w[0], tuple_w[1] / norm)
+                          for i, tuple_w in enumerate(lime_weights)]
 
     a = 0.0
     n = 0
     for i, tuple_w in enumerate(normalized_weights):
         if tuple_w[1] < 0:
             continue
-        if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
+        if len(np.where(segments == tuple_w[0])[
+                0]) < num_pixels_threshold_in_a_sp:
             continue
 
         a += tuple_w[1]
@@ -406,12 +632,18 @@ def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show
         return 5
 
     if n == 0:
-        return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)
+        return auto_choose_num_features_to_show(lime_interpreter, label,
+                                                percentage_to_show - 0.1)
 
     return n
 
 
-def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
+def get_cam(image_show,
+            feature_maps,
+            fc_weights,
+            label_index,
+            cam_min=None,
+            cam_max=None):
     _, nc, h, w = feature_maps.shape
 
     cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
@@ -425,7 +657,8 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam
     cam = cam - cam_min
     cam = cam / cam_max
     cam = np.uint8(255 * cam)
-    cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
+    cam_img = cv2.resize(
+        cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
 
     heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
     heatmap = np.float32(heatmap)
@@ -437,34 +670,11 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam
 
 def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
     import matplotlib.pyplot as plt
-    if isinstance(data_, str):
-        if algorithm_name == 'cam':
-            f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
-        else:
-            f_out = "{}_{}_s{}.png".format(algorithm_name, data_.split('/')[-1], num_samples)
-        plt.savefig(
-            os.path.join(save_outdir, f_out)
-        )
+    if algorithm_name == 'cam':
+        f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
     else:
-        n = 0
-        if algorithm_name == 'cam':
-            f_out = 'cam-{}.png'.format(n)
-        else:
-            f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
-        while os.path.exists(
-                os.path.join(save_outdir, f_out)
-        ):
-            n += 1
-            if algorithm_name == 'cam':
-                f_out = 'cam-{}.png'.format(n)
-            else:
-                f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
-            continue
-        plt.savefig(
-            os.path.join(
-                save_outdir, f_out
-            )
-        )
-    logging.info('The image of intrepretation result save in {}'.format(os.path.join(
-                save_outdir, f_out
-            )))
+        f_out = "{}_{}_s{}.png".format(save_outdir, algorithm_name,
+                                       num_samples)
+
+    plt.savefig(f_out)
+    logging.info('The image of intrepretation result save in {}'.format(f_out))

+ 104 - 68
paddlex/interpret/core/lime_base.py

@@ -27,7 +27,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
 """
 
-
 import numpy as np
 import scipy as sp
 
@@ -39,10 +38,8 @@ import paddlex.utils.logging as logging
 
 class LimeBase(object):
     """Class for learning a locally linear sparse model from perturbed data"""
-    def __init__(self,
-                 kernel_fn,
-                 verbose=False,
-                 random_state=None):
+
+    def __init__(self, kernel_fn, verbose=False, random_state=None):
         """Init function
 
         Args:
@@ -72,15 +69,14 @@ class LimeBase(object):
         """
         from sklearn.linear_model import lars_path
         x_vector = weighted_data
-        alphas, _, coefs = lars_path(x_vector,
-                                     weighted_labels,
-                                     method='lasso',
-                                     verbose=False)
+        alphas, _, coefs = lars_path(
+            x_vector, weighted_labels, method='lasso', verbose=False)
         return alphas, coefs
 
     def forward_selection(self, data, labels, weights, num_features):
         """Iteratively adds features to the model"""
-        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
+        clf = Ridge(
+            alpha=0, fit_intercept=True, random_state=self.random_state)
         used_features = []
         for _ in range(min(num_features, data.shape[1])):
             max_ = -100000000
@@ -88,11 +84,13 @@ class LimeBase(object):
             for feature in range(data.shape[1]):
                 if feature in used_features:
                     continue
-                clf.fit(data[:, used_features + [feature]], labels,
+                clf.fit(data[:, used_features + [feature]],
+                        labels,
                         sample_weight=weights)
-                score = clf.score(data[:, used_features + [feature]],
-                                  labels,
-                                  sample_weight=weights)
+                score = clf.score(
+                    data[:, used_features + [feature]],
+                    labels,
+                    sample_weight=weights)
                 if score > max_:
                     best = feature
                     max_ = score
@@ -108,8 +106,8 @@ class LimeBase(object):
         elif method == 'forward_selection':
             return self.forward_selection(data, labels, weights, num_features)
         elif method == 'highest_weights':
-            clf = Ridge(alpha=0.01, fit_intercept=True,
-                        random_state=self.random_state)
+            clf = Ridge(
+                alpha=0.01, fit_intercept=True, random_state=self.random_state)
             clf.fit(data, labels, sample_weight=weights)
 
             coef = clf.coef_
@@ -125,7 +123,8 @@ class LimeBase(object):
                     nnz_indexes = argsort_data[::-1]
                     indices = weighted_data.indices[nnz_indexes]
                     num_to_pad = num_features - sdata
-                    indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
+                    indices = np.concatenate((indices, np.zeros(
+                        num_to_pad, dtype=indices.dtype)))
                     indices_set = set(indices)
                     pad_counter = 0
                     for i in range(data.shape[1]):
@@ -135,7 +134,8 @@ class LimeBase(object):
                             if pad_counter >= num_to_pad:
                                 break
                 else:
-                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
+                    nnz_indexes = argsort_data[sdata - num_features:sdata][::
+                                                                           -1]
                     indices = weighted_data.indices[nnz_indexes]
                 return indices
             else:
@@ -146,13 +146,13 @@ class LimeBase(object):
                     reverse=True)
                 return np.array([x[0] for x in feature_weights[:num_features]])
         elif method == 'lasso_path':
-            weighted_data = ((data - np.average(data, axis=0, weights=weights))
-                             * np.sqrt(weights[:, np.newaxis]))
-            weighted_labels = ((labels - np.average(labels, weights=weights))
-                               * np.sqrt(weights))
+            weighted_data = ((data - np.average(
+                data, axis=0, weights=weights)) *
+                             np.sqrt(weights[:, np.newaxis]))
+            weighted_labels = ((labels - np.average(
+                labels, weights=weights)) * np.sqrt(weights))
             nonzero = range(weighted_data.shape[1])
-            _, coefs = self.generate_lars_path(weighted_data,
-                                               weighted_labels)
+            _, coefs = self.generate_lars_path(weighted_data, weighted_labels)
             for i in range(len(coefs.T) - 1, 0, -1):
                 nonzero = coefs.T[i].nonzero()[0]
                 if len(nonzero) <= num_features:
@@ -164,8 +164,8 @@ class LimeBase(object):
                 n_method = 'forward_selection'
             else:
                 n_method = 'highest_weights'
-            return self.feature_selection(data, labels, weights,
-                                          num_features, n_method)
+            return self.feature_selection(data, labels, weights, num_features,
+                                          n_method)
 
     def interpret_instance_with_data(self,
                                      neighborhood_data,
@@ -214,30 +214,31 @@ class LimeBase(object):
         weights = self.kernel_fn(distances)
         labels_column = neighborhood_labels[:, label]
         used_features = self.feature_selection(neighborhood_data,
-                                               labels_column,
-                                               weights,
-                                               num_features,
-                                               feature_selection)
+                                               labels_column, weights,
+                                               num_features, feature_selection)
         if model_regressor is None:
-            model_regressor = Ridge(alpha=1, fit_intercept=True,
-                                    random_state=self.random_state)
+            model_regressor = Ridge(
+                alpha=1, fit_intercept=True, random_state=self.random_state)
         easy_model = model_regressor
         easy_model.fit(neighborhood_data[:, used_features],
-                       labels_column, sample_weight=weights)
+                       labels_column,
+                       sample_weight=weights)
         prediction_score = easy_model.score(
             neighborhood_data[:, used_features],
-            labels_column, sample_weight=weights)
+            labels_column,
+            sample_weight=weights)
 
-        local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
+        local_pred = easy_model.predict(neighborhood_data[0, used_features]
+                                        .reshape(1, -1))
 
         if self.verbose:
             logging.info('Intercept' + str(easy_model.intercept_))
             logging.info('Prediction_local' + str(local_pred))
             logging.info('Right:' + str(neighborhood_labels[0, label]))
-        return (easy_model.intercept_,
-                sorted(zip(used_features, easy_model.coef_),
-                       key=lambda x: np.abs(x[1]), reverse=True),
-                prediction_score, local_pred)
+        return (easy_model.intercept_, sorted(
+            zip(used_features, easy_model.coef_),
+            key=lambda x: np.abs(x[1]),
+            reverse=True), prediction_score, local_pred)
 
 
 class ImageInterpretation(object):
@@ -254,8 +255,13 @@ class ImageInterpretation(object):
         self.local_weights = {}
         self.local_pred = None
 
-    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
-                           num_features=5, min_weight=0.):
+    def get_image_and_mask(self,
+                           label,
+                           positive_only=True,
+                           negative_only=False,
+                           hide_rest=False,
+                           num_features=5,
+                           min_weight=0.):
         """Init function.
 
         Args:
@@ -279,7 +285,9 @@ class ImageInterpretation(object):
         if label not in self.local_weights:
             raise KeyError('Label not in interpretation')
         if positive_only & negative_only:
-            raise ValueError("Positive_only and negative_only cannot be true at the same time.")
+            raise ValueError(
+                "Positive_only and negative_only cannot be true at the same time."
+            )
         segments = self.segments
         image = self.image
         local_weights_label = self.local_weights[label]
@@ -289,14 +297,20 @@ class ImageInterpretation(object):
         else:
             temp = self.image.copy()
         if positive_only:
-            fs = [x[0] for x in local_weights_label
-                  if x[1] > 0 and x[1] > min_weight][:num_features]
+            fs = [
+                x[0] for x in local_weights_label
+                if x[1] > 0 and x[1] > min_weight
+            ][:num_features]
         if negative_only:
-            fs = [x[0] for x in local_weights_label
-                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
+            fs = [
+                x[0] for x in local_weights_label
+                if x[1] < 0 and abs(x[1]) > min_weight
+            ][:num_features]
         if positive_only or negative_only:
+            c = 1 if positive_only else 0
             for f in fs:
-                temp[segments == f] = image[segments == f].copy()
+                temp[segments == f] = [0, 255, 0]
+                # temp[segments == f, c] = np.max(image)
                 mask[segments == f] = 1
             return temp, mask
         else:
@@ -330,8 +344,11 @@ class ImageInterpretation(object):
         temp = np.zeros_like(image)
 
         weight_max = abs(local_weights_label[0][1])
-        local_weights_label = [(f, w/weight_max) for f, w in local_weights_label]
-        local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True)  # negatives are at last.
+        local_weights_label = [(f, w / weight_max)
+                               for f, w in local_weights_label]
+        local_weights_label = sorted(
+            local_weights_label, key=lambda x: x[1],
+            reverse=True)  # negatives are at last.
 
         cmaps = cm.get_cmap('Spectral')
         colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
@@ -354,8 +371,12 @@ class LimeImageInterpreter(object):
     feature that is 1 when the value is the same as the instance being
     interpreted."""
 
-    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
-                 feature_selection='auto', random_state=None):
+    def __init__(self,
+                 kernel_width=.25,
+                 kernel=None,
+                 verbose=False,
+                 feature_selection='auto',
+                 random_state=None):
         """Init function.
 
         Args:
@@ -377,22 +398,27 @@ class LimeImageInterpreter(object):
         kernel_width = float(kernel_width)
 
         if kernel is None:
+
             def kernel(d, kernel_width):
-                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
+                return np.sqrt(np.exp(-(d**2) / kernel_width**2))
 
         kernel_fn = partial(kernel, kernel_width=kernel_width)
 
         self.random_state = check_random_state(random_state)
         self.feature_selection = feature_selection
-        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
+        self.base = LimeBase(
+            kernel_fn, verbose, random_state=self.random_state)
 
-    def interpret_instance(self, image, classifier_fn, labels=(1,),
+    def interpret_instance(self,
+                           image,
+                           classifier_fn,
+                           labels=(1, ),
                            hide_color=None,
-                           num_features=100000, num_samples=1000,
+                           num_features=100000,
+                           num_samples=1000,
                            batch_size=10,
                            distance_metric='cosine',
-                           model_regressor=None
-                           ):
+                           model_regressor=None):
         """Generates interpretations for a prediction.
 
         First, we generate neighborhood data by randomly perturbing features
@@ -435,6 +461,7 @@ class LimeImageInterpreter(object):
         self.segments = segments
 
         fudged_image = image.copy()
+        # global_mean = np.mean(image, (0, 1))
         if hide_color is None:
             # if no hide_color, use the mean
             for x in np.unique(segments):
@@ -461,24 +488,30 @@ class LimeImageInterpreter(object):
 
         top = labels
 
-        data, labels = self.data_labels(image, fudged_image, segments,
-                                        classifier_fn, num_samples,
-                                        batch_size=batch_size)
+        data, labels = self.data_labels(
+            image,
+            fudged_image,
+            segments,
+            classifier_fn,
+            num_samples,
+            batch_size=batch_size)
 
         distances = sklearn.metrics.pairwise_distances(
-            data,
-            data[0].reshape(1, -1),
-            metric=distance_metric
-        ).ravel()
+            data, data[0].reshape(1, -1), metric=distance_metric).ravel()
 
         interpretation_image = ImageInterpretation(image, segments)
         for label in top:
             (interpretation_image.intercept[label],
              interpretation_image.local_weights[label],
-             interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data(
-                data, labels, distances, label, num_features,
-                model_regressor=model_regressor,
-                feature_selection=self.feature_selection)
+             interpretation_image.score, interpretation_image.local_pred
+             ) = self.base.interpret_instance_with_data(
+                 data,
+                 labels,
+                 distances,
+                 label,
+                 num_features,
+                 model_regressor=model_regressor,
+                 feature_selection=self.feature_selection)
         return interpretation_image
 
     def data_labels(self,
@@ -511,6 +544,9 @@ class LimeImageInterpreter(object):
         labels = []
         data[0, :] = 1
         imgs = []
+
+        logging.info("Computing LIME.", use_color=True)
+
         for row in tqdm.tqdm(data):
             temp = copy.deepcopy(image)
             zeros = np.where(row == 0)[0]

+ 219 - 38
paddlex/interpret/core/normlime_base.py

@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import numpy as np
 import glob
+import tqdm
 
 from paddlex.interpret.as_data_reader.readers import read_image
 import paddlex.utils.logging as logging
@@ -38,18 +39,24 @@ def combine_normlime_and_lime(lime_weights, g_weights):
 
     for y in pred_labels:
         normlized_lime_weights_y = lime_weights[y]
-        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
+        lime_weights_dict = {
+            tuple_w[0]: tuple_w[1]
+            for tuple_w in normlized_lime_weights_y
+        }
 
         normlized_g_weight_y = g_weights[y]
-        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
+        normlime_weights_dict = {
+            tuple_w[0]: tuple_w[1]
+            for tuple_w in normlized_g_weight_y
+        }
 
         combined_weights[y] = [
             (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
             for seg_k in lime_weights_dict.keys()
         ]
 
-        combined_weights[y] = sorted(combined_weights[y],
-                                     key=lambda x: np.abs(x[1]), reverse=True)
+        combined_weights[y] = sorted(
+            combined_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
 
     return combined_weights
 
@@ -67,7 +74,8 @@ def centroid_using_superpixels(features, segments):
     regions = regionprops(segments + 1)
     one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
     for i, r in enumerate(regions):
-        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
+        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] +
+                                                             0.5), :]
     return one_list
 
 
@@ -80,30 +88,39 @@ def get_feature_for_kmeans(feature_map, segments):
     return x
 
 
-def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_size=50, save_dir='./tmp'):
+def precompute_normlime_weights(list_data_,
+                                predict_fn,
+                                num_samples=3000,
+                                batch_size=50,
+                                save_dir='./tmp'):
     # save lime weights and kmeans cluster labels
-    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
+    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
+                            save_dir)
 
     # load precomputed results, compute normlime weights and save.
-    fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
+    fname_list = glob.glob(
+        os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
     return compute_normlime_weights(fname_list, save_dir, num_samples)
 
 
-def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path):
+def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels,
+                                           cluster_labels, save_path):
 
     lime_weights = {}
     for label in image_pred_labels:
         lime_weights[label] = lime_all_weights[label]
 
     for_normlime_weights = {
-        'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
+        'lime_weights':
+        lime_weights,  # a dict: class_label: (seg_label, weight)
         'cluster': cluster_labels  # a list with segments as indices.
     }
 
     np.save(save_path, for_normlime_weights)
 
 
-def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
+def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
+                            save_dir):
     root_path = gen_user_home()
     root_path = osp.join(root_path, '.paddlex')
     h_pre_models = osp.join(root_path, "pre_models")
@@ -117,17 +134,24 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
 
     for data_index, each_data_ in enumerate(list_data_):
         if isinstance(each_data_, str):
-            save_path = "lime_weights_s{}_{}.npy".format(num_samples, each_data_.split('/')[-1].split('.')[0])
+            save_path = "lime_weights_s{}_{}.npy".format(
+                num_samples, each_data_.split('/')[-1].split('.')[0])
             save_path = os.path.join(save_dir, save_path)
         else:
-            save_path = "lime_weights_s{}_{}.npy".format(num_samples, data_index)
+            save_path = "lime_weights_s{}_{}.npy".format(num_samples,
+                                                         data_index)
             save_path = os.path.join(save_dir, save_path)
 
         if os.path.exists(save_path):
-            logging.info(save_path + ' exists, not computing this one.', use_color=True)
+            logging.info(
+                save_path + ' exists, not computing this one.', use_color=True)
             continue
-        img_file_name = each_data_ if isinstance(each_data_, str) else data_index
-        logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True)
+        img_file_name = each_data_ if isinstance(each_data_,
+                                                 str) else data_index
+        logging.info(
+            'processing ' + img_file_name + ' [{}/{}]'.format(data_index,
+                                                              len(list_data_)),
+            use_color=True)
 
         image_show = read_image(each_data_)
         result = predict_fn(image_show)
@@ -156,32 +180,38 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
         pred_label = pred_label[:top_k]
 
         algo = lime_base.LimeImageInterpreter()
-        interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
-                                          num_samples=num_samples, batch_size=batch_size)
-
-        X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
+        interpreter = algo.interpret_instance(
+            image_show[0],
+            predict_fn,
+            pred_label,
+            0,
+            num_samples=num_samples,
+            batch_size=batch_size)
+
+        X = get_feature_for_kmeans(
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
+            interpreter.segments)
         try:
             cluster_labels = kmeans_model.predict(X)
         except AttributeError:
             from sklearn.metrics import pairwise_distances_argmin_min
-            cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_)
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, kmeans_model.cluster_centers_)
         save_one_lime_predict_and_kmean_labels(
-            interpreter.local_weights, pred_label,
-            cluster_labels,
-            save_path
-        )
+            interpreter.local_weights, pred_label, cluster_labels, save_path)
 
 
 def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
     normlime_weights_all_labels = {}
-    
+
     for f in a_list_lime_fnames:
         try:
             lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
             lime_weights = lime_weights_and_cluster['lime_weights']
             cluster = lime_weights_and_cluster['cluster']
         except:
-            logging.info('When loading precomputed LIME result, skipping' + str(f))
+            logging.info('When loading precomputed LIME result, skipping' +
+                         str(f))
             continue
         logging.info('Loading precomputed LIME result,' + str(f))
         pred_labels = lime_weights.keys()
@@ -203,10 +233,12 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
     for y in normlime_weights_all_labels:
         normlime_weights = normlime_weights_all_labels.get(y, {})
         for k in normlime_weights:
-            normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k])
+            normlime_weights[k] = sum(normlime_weights[k]) / len(
+                normlime_weights[k])
 
     # check normlime
-    if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1:
+    if len(normlime_weights_all_labels.keys()) < max(
+            normlime_weights_all_labels.keys()) + 1:
         logging.info(
             "\n" + \
             "Warning: !!! \n" + \
@@ -218,17 +250,166 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
         )
 
     n = 0
-    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
-    while os.path.exists(
-            os.path.join(save_dir, f_out)
-    ):
+    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
+        lime_num_samples, len(a_list_lime_fnames), n)
+    while os.path.exists(os.path.join(save_dir, f_out)):
         n += 1
-        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
+        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
+            lime_num_samples, len(a_list_lime_fnames), n)
         continue
 
-    np.save(
-        os.path.join(save_dir, f_out),
-        normlime_weights_all_labels
-    )
+    np.save(os.path.join(save_dir, f_out), normlime_weights_all_labels)
     return os.path.join(save_dir, f_out)
 
+
+def precompute_global_classifier(dataset,
+                                 predict_fn,
+                                 save_path,
+                                 batch_size=50,
+                                 max_num_samples=1000):
+    from sklearn.linear_model import LogisticRegression
+
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    if not osp.exists(h_pre_models):
+        if not osp.exists(root_path):
+            os.makedirs(root_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=root_path)
+    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
+    kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+
+    image_list = []
+    for item in dataset.file_list:
+        image_list.append(item[0])
+
+    x_data = []
+    y_labels = []
+
+    num_features = len(kmeans_model.cluster_centers_)
+
+    logging.info(
+        "Initialization for NormLIME: Computing each sample in the test list.",
+        use_color=True)
+
+    for each_data_ in tqdm.tqdm(image_list):
+        x_data_i = np.zeros((num_features))
+        image_show = read_image(each_data_)
+        result = predict_fn(image_show)
+        result = result[0]  # only one image here.
+        c = compute_features_for_kmeans(image_show).transpose((1, 2, 0))
+
+        segments = np.zeros((image_show.shape[1], image_show.shape[2]),
+                            np.int32)
+        num_blocks = 10
+        height_per_i = segments.shape[0] // num_blocks + 1
+        width_per_i = segments.shape[1] // num_blocks + 1
+
+        for i in range(segments.shape[0]):
+            for j in range(segments.shape[1]):
+                segments[i,
+                         j] = i // height_per_i * num_blocks + j // width_per_i
+
+        # segments = quickshift(image_show[0], sigma=1)
+        X = get_feature_for_kmeans(c, segments)
+
+        try:
+            cluster_labels = kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(
+                X, kmeans_model.cluster_centers_)
+
+        for c in cluster_labels:
+            x_data_i[c] = 1
+
+        # x_data_i /= len(cluster_labels)
+
+        pred_y_i = np.argmax(result)
+        y_labels.append(pred_y_i)
+        x_data.append(x_data_i)
+
+    if len(np.unique(y_labels)) < 2:
+        logging.info("Warning: The test samples in the dataset is limited.\n \
+                     NormLIME may have no effect on the results.\n \
+                     Try to add more test samples, or see the results of LIME.")
+        num_classes = np.max(np.unique(y_labels)) + 1
+        normlime_weights_all_labels = {}
+        for class_index in range(num_classes):
+            w = np.ones((num_features)) / num_features
+            normlime_weights_all_labels[class_index] = {
+                i: wi
+                for i, wi in enumerate(w)
+            }
+        logging.info("Saving the computed normlime_weights in {}".format(
+            save_path))
+
+        np.save(save_path, normlime_weights_all_labels)
+        return save_path
+
+    clf = LogisticRegression(multi_class='multinomial', max_iter=1000)
+    clf.fit(x_data, y_labels)
+
+    num_classes = np.max(np.unique(y_labels)) + 1
+    normlime_weights_all_labels = {}
+
+    if len(y_labels) / len(np.unique(y_labels)) < 3:
+        logging.info("Warning: The test samples in the dataset is limited.\n \
+                     NormLIME may have no effect on the results.\n \
+                     Try to add more test samples, or see the results of LIME.")
+
+    if len(np.unique(y_labels)) == 2:
+        # binary: clf.coef_ has shape of [1, num_features]
+        for class_index in range(num_classes):
+            if class_index not in clf.classes_:
+                w = np.ones((num_features)) / num_features
+                normlime_weights_all_labels[class_index] = {
+                    i: wi
+                    for i, wi in enumerate(w)
+                }
+                continue
+
+            if clf.classes_[0] == class_index:
+                w = -clf.coef_[0]
+            else:
+                w = clf.coef_[0]
+
+            # softmax
+            w = w - np.max(w)
+            exp_w = np.exp(w * 10)
+            w = exp_w / np.sum(exp_w)
+
+            normlime_weights_all_labels[class_index] = {
+                i: wi
+                for i, wi in enumerate(w)
+            }
+    else:
+        # clf.coef_ has shape of [len(np.unique(y_labels)), num_features]
+        for class_index in range(num_classes):
+            if class_index not in clf.classes_:
+                w = np.ones((num_features)) / num_features
+                normlime_weights_all_labels[class_index] = {
+                    i: wi
+                    for i, wi in enumerate(w)
+                }
+                continue
+
+            coef_class_index = np.where(clf.classes_ == class_index)[0][0]
+            w = clf.coef_[coef_class_index]
+
+            # softmax
+            w = w - np.max(w)
+            exp_w = np.exp(w * 10)
+            w = exp_w / np.sum(exp_w)
+
+            normlime_weights_all_labels[class_index] = {
+                i: wi
+                for i, wi in enumerate(w)
+            }
+
+    logging.info("Saving the computed normlime_weights in {}".format(
+        save_path))
+    np.save(save_path, normlime_weights_all_labels)
+
+    return save_path

+ 18 - 9
paddlex/interpret/interpretation_predict.py

@@ -13,17 +13,26 @@
 # limitations under the License.
 
 import numpy as np
+import cv2
+import copy
+
 
 def interpretation_predict(model, images):
-    model.arrange_transforms(
-            transforms=model.test_transforms, mode='test')
+    images = images.astype('float32')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
+    tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
+    model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+
     new_imgs = []
     for i in range(images.shape[0]):
-        img = images[i]
-        new_imgs.append(model.test_transforms(img)[0])
+        images[i] = cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR)
+        new_imgs.append(model.test_transforms(images[i])[0])
+
     new_imgs = np.array(new_imgs)
-    result = model.exe.run(
-        model.test_prog,
-        feed={'image': new_imgs},
-        fetch_list=list(model.interpretation_feats.values()))
-    return result
+    out = model.exe.run(model.test_prog,
+                        feed={'image': new_imgs},
+                        fetch_list=list(model.interpretation_feats.values()))
+
+    model.test_transforms.transforms = tmp_transforms
+
+    return out

+ 85 - 88
paddlex/interpret/visualize.py

@@ -20,79 +20,79 @@ import numpy as np
 import paddlex as pdx
 from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
-from .core.normlime_base import precompute_normlime_weights
+from .core.normlime_base import precompute_global_classifier
 from .core._session_preparation import gen_user_home
-   
-def lime(img_file, 
-         model, 
-         num_samples=3000, 
-         batch_size=50,
-         save_dir='./'):
-    """使用LIME算法将模型预测结果的可解释性可视化。 
-    
+
+
+def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
+    """使用LIME算法将模型预测结果的可解释性可视化。
+
     LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
     在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
     和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
-    得到每个输入维度的权重,以此来解释模型。  
-    
+    得到每个输入维度的权重,以此来解释模型。
+
     注意:LIME可解释性结果可视化目前只支持分类模型。
-         
+
     Args:
         img_file (str): 预测图像路径。
         model (paddlex.cv.models): paddlex中的模型。
         num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
         batch_size (int): 预测数据batch大小,默认为50。
-        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
+        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
     """
     assert model.model_type == 'classifier', \
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
-        raise Exception('The interpretation only can deal with the Normal model')
+        raise Exception(
+            'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(
-                transforms=model.test_transforms, mode='test')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]
     img = np.around(img).astype('uint8')
     img = np.expand_dims(img, axis=0)
     interpreter = None
-    interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
+    interpreter = get_lime_interpreter(
+        img, model, num_samples=num_samples, batch_size=batch_size)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    interpreter.interpret(img, save_dir=save_dir)
-    
-    
-def normlime(img_file, 
-              model, 
-              dataset=None,
-              num_samples=3000, 
-              batch_size=50,
-              save_dir='./'):
+    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
+
+
+def normlime(img_file,
+             model,
+             dataset=None,
+             num_samples=3000,
+             batch_size=50,
+             save_dir='./',
+             normlime_weights_file=None):
     """使用NormLIME算法将模型预测结果的可解释性可视化。
-    
+
     NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
     试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
-    
+
     注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
     注意2:NormLIME可解释性结果可视化目前只支持分类模型。
-         
+
     Args:
         img_file (str): 预测图像路径。
         model (paddlex.cv.models): paddlex中的模型。
         dataset (paddlex.datasets): 数据集读取器,默认为None。
         num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
         batch_size (int): 预测数据batch大小,默认为50。
-        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
+        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
+        normlime_weights_file (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。
     """
     assert model.model_type == 'classifier', \
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
-        raise Exception('The interpretation only can deal with the Normal model')
+        raise Exception(
+            'The interpretation only can deal with the Normal model')
     if not osp.exists(save_dir):
         os.makedirs(save_dir)
-    model.arrange_transforms(
-                transforms=model.test_transforms, mode='test')
+    model.arrange_transforms(transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
     tmp_transforms.transforms = tmp_transforms.transforms[:-2]
     img = tmp_transforms(img_file)[0]
@@ -100,52 +100,48 @@ def normlime(img_file,
     img = np.expand_dims(img, axis=0)
     interpreter = None
     if dataset is None:
-        raise Exception('The dataset is None. Cannot implement this kind of interpretation')
-    interpreter = get_normlime_interpreter(img, model, dataset, 
-                                 num_samples=num_samples, batch_size=batch_size,
-                                     save_dir=save_dir)
+        raise Exception(
+            'The dataset is None. Cannot implement this kind of interpretation')
+    interpreter = get_normlime_interpreter(
+        img,
+        model,
+        dataset,
+        num_samples=num_samples,
+        batch_size=batch_size,
+        save_dir=save_dir,
+        normlime_weights_file=normlime_weights_file)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    interpreter.interpret(img, save_dir=save_dir)
-    
-    
+    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
+
+
 def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
     def predict_func(image):
-        image = image.astype('float32')
-        for i in range(image.shape[0]):
-            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
         return out[0]
+
     labels_name = None
     if hasattr(model, 'labels'):
         labels_name = model.labels
-    interpreter = Interpretation('lime', 
-                            predict_func,
-                            labels_name,
-                            num_samples=num_samples, 
-                            batch_size=batch_size)
+    interpreter = Interpretation(
+        'lime',
+        predict_func,
+        labels_name,
+        num_samples=num_samples,
+        batch_size=batch_size)
     return interpreter
 
 
-def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
-    def precompute_predict_func(image):
-        image = image.astype('float32')
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
-        out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
-        return out[0]
+def get_normlime_interpreter(img,
+                             model,
+                             dataset,
+                             num_samples=3000,
+                             batch_size=50,
+                             save_dir='./',
+                             normlime_weights_file=None):
     def predict_func(image):
-        image = image.astype('float32')
-        for i in range(image.shape[0]):
-            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
-        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
-        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = interpretation_predict(model, image)
-        model.test_transforms.transforms = tmp_transforms
         return out[0]
+
     labels_name = None
     if dataset is not None:
         labels_name = dataset.labels
@@ -157,28 +153,29 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
             os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         pdx.utils.download_and_decompress(url, path=root_path)
-    npy_dir = precompute_for_normlime(precompute_predict_func, 
-                                      dataset, 
-                                      num_samples=num_samples, 
-                                      batch_size=batch_size,
-                                      save_dir=save_dir)
-    interpreter = Interpretation('normlime', 
-                            predict_func,
-                            labels_name,
-                            num_samples=num_samples, 
-                            batch_size=batch_size,
-                            normlime_weights=npy_dir)
-    return interpreter
-
 
-def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
-    image_list = []
-    for item in dataset.file_list:
-        image_list.append(item[0])
-    return precompute_normlime_weights(
-            image_list,  
+    if osp.exists(osp.join(save_dir, normlime_weights_file)):
+        normlime_weights_file = osp.join(save_dir, normlime_weights_file)
+        try:
+            np.load(normlime_weights_file, allow_pickle=True).item()
+        except:
+            normlime_weights_file = precompute_global_classifier(
+                dataset,
+                predict_func,
+                save_path=normlime_weights_file,
+                batch_size=batch_size)
+    else:
+        normlime_weights_file = precompute_global_classifier(
+            dataset,
             predict_func,
-            num_samples=num_samples, 
-            batch_size=batch_size,
-            save_dir=save_dir)
-  
+            save_path=osp.join(save_dir, normlime_weights_file),
+            batch_size=batch_size)
+
+    interpreter = Interpretation(
+        'normlime',
+        predict_func,
+        labels_name,
+        num_samples=num_samples,
+        batch_size=batch_size,
+        normlime_weights=normlime_weights_file)
+    return interpreter

+ 1 - 1
setup.py

@@ -19,7 +19,7 @@ long_description = "PaddleX. A end-to-end deeplearning model development toolkit
 
 setuptools.setup(
     name="paddlex",
-    version='1.0.5',
+    version='1.0.6',
     author="paddlex",
     author_email="paddlex@baidu.com",
     description=long_description,

+ 12 - 8
tutorials/interpret/normlime.py

@@ -14,18 +14,22 @@ model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilene
 pdx.utils.download_and_decompress(model_file, path='./')
 
 # 加载模型
-model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+model_file = 'mini_imagenet_veg_mobilenetv2'
+model = pdx.load_model(model_file)
 
 # 定义测试所用的数据集
+dataset = 'mini_imagenet_veg'
 test_dataset = pdx.datasets.ImageNet(
-    data_dir='mini_imagenet_veg',
-    file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
-    label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
+    data_dir=dataset,
+    file_list=osp.join(dataset, 'test_list.txt'),
+    label_list=osp.join(dataset, 'labels.txt'),
     transforms=model.test_transforms)
 
 # 可解释性可视化
 pdx.interpret.normlime(
-         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
-          model, 
-          test_dataset, 
-          save_dir='./')
+    test_dataset.file_list[0][0],
+    model,
+    test_dataset,
+    save_dir='./',
+    normlime_weights_file='{}_{}.npy'.format(
+        dataset.split('/')[-1], model.model_name))