Przeglądaj źródła

Merge pull request #90 from SunAhong1993/syf0520

add interpret vis docs
Jason 5 lat temu
rodzic
commit
c7ae49ff4d

+ 39 - 12
docs/apis/visualize.md

@@ -114,27 +114,54 @@ pdx.slim.visualize(model, 'mobilenetv2.sensitivities', save_dir='./')
 # 可视化结果保存在./sensitivities.png
 ```
 
-## 可解释性结果可视化
+## LIME可解释性结果可视化
 ```
-paddlex.interpret.visualize(img_file, 
-                            model, 
-                            dataset=None, 
-                            algo='lime',
-                            num_samples=3000, 
-                            batch_size=50,
-                            save_dir='./')
+paddlex.interpret.lime(img_file, 
+                       model, 
+                       num_samples=3000, 
+                       batch_size=50,
+                       save_dir='./')
 ```
-将模型预测结果的可解释性可视化,目前只支持分类模型。
+使用LIME算法将模型预测结果的可解释性可视化。  
+LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,得到每个输入维度的权重,以此来解释模型。    
+
+**注意:** 可解释性结果可视化目前只支持分类模型。
 
 ### 参数
 >* **img_file** (str): 预测图像路径。
 >* **model** (paddlex.cv.models): paddlex中的模型。
->* **dataset** (paddlex.datasets): 数据集读取器,默认为None。
->* **algo** (str): 可解释性方式,当前可选'lime'和'normlime'。
 >* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。
 >* **batch_size** (int): 预测数据batch大小,默认为50。
 >* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 
 
 
 ### 使用示例
-> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。
+> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/lime.py)。
+
+
+## NormLIME可解释性结果可视化
+```
+paddlex.interpret.normlime(img_file, 
+                           model, 
+                           dataset=None,
+                           num_samples=3000, 
+                           batch_size=50,
+                           save_dir='./')
+```
+使用NormLIME算法将模型预测结果的可解释性可视化。
+NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
+
+**注意:** 可解释性结果可视化目前只支持分类模型。
+
+### 参数
+>* **img_file** (str): 预测图像路径。
+>* **model** (paddlex.cv.models): paddlex中的模型。
+>* **dataset** (paddlex.datasets): 数据集读取器,默认为None。
+>* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。
+>* **batch_size** (int): 预测数据batch大小,默认为50。
+>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 
+
+**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
+### 使用示例
+> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。
+

+ 2 - 1
paddlex/interpret/__init__.py

@@ -15,4 +15,5 @@
 from __future__ import absolute_import
 from . import visualize
 
-visualize = visualize.visualize
+lime = visualize.lime
+normlime = visualize.normlime

+ 2 - 3
paddlex/interpret/core/normlime_base.py

@@ -116,9 +116,8 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
         if os.path.exists(save_path):
             logging.info(save_path + ' exists, not computing this one.', use_color=True)
             continue
-
-        logging.info('processing'+each_data_ if isinstance(each_data_, str) else data_index + \
-              f'+{data_index}/{len(list_data_)}', use_color=True)
+        img_file_name = each_data_ if isinstance(each_data_, str) else data_index
+        logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True)
 
         image_show = read_image(each_data_)
         result = predict_fn(image_show)

+ 57 - 17
paddlex/interpret/visualize.py

@@ -22,20 +22,65 @@ from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights
 from .core._session_preparation import gen_user_home
-
-def visualize(img_file, 
+   
+def lime(img_file, 
+         model, 
+         num_samples=3000, 
+         batch_size=50,
+         save_dir='./'):
+    """使用LIME算法将模型预测结果的可解释性可视化。 
+    
+    LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
+    在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
+    和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
+    得到每个输入维度的权重,以此来解释模型。  
+    
+    注意:LIME可解释性结果可视化目前只支持分类模型。
+         
+    Args:
+        img_file (str): 预测图像路径。
+        model (paddlex.cv.models): paddlex中的模型。
+        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
+        batch_size (int): 预测数据batch大小,默认为50。
+        save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
+    """
+    assert model.model_type == 'classifier', \
+        'Now the interpretation visualize only be supported in classifier!'
+    if model.status != 'Normal':
+        raise Exception('The interpretation only can deal with the Normal model')
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
+    model.arrange_transforms(
+                transforms=model.test_transforms, mode='test')
+    tmp_transforms = copy.deepcopy(model.test_transforms)
+    tmp_transforms.transforms = tmp_transforms.transforms[:-2]
+    img = tmp_transforms(img_file)[0]
+    img = np.around(img).astype('uint8')
+    img = np.expand_dims(img, axis=0)
+    interpreter = None
+    interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
+    img_name = osp.splitext(osp.split(img_file)[-1])[0]
+    interpreter.interpret(img, save_dir=save_dir)
+    
+    
+def normlime(img_file, 
               model, 
               dataset=None,
-              algo='lime',
               num_samples=3000, 
               batch_size=50,
               save_dir='./'):
-    """可解释性可视化。
+    """使用NormLIME算法将模型预测结果的可解释性可视化。
+    
+    NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
+    试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
+    
+    注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
+    注意2:NormLIME可解释性结果可视化目前只支持分类模型。
+         
     Args:
         img_file (str): 预测图像路径。
         model (paddlex.cv.models): paddlex中的模型。
         dataset (paddlex.datasets): 数据集读取器,默认为None。
-        algo (str): 可解释性方式,当前可选'lime'和'normlime'。
         num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
         batch_size (int): 预测数据batch大小,默认为50。
         save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
@@ -54,21 +99,16 @@ def visualize(img_file,
     img = np.around(img).astype('uint8')
     img = np.expand_dims(img, axis=0)
     interpreter = None
-    if algo == 'lime':
-        interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
-    elif algo == 'normlime':
-        if dataset is None:
-            raise Exception('The dataset is None. Cannot implement this kind of interpretation')
-        interpreter = get_normlime_interpreter(img, model, dataset, 
-                                     num_samples=num_samples, batch_size=batch_size,
+    if dataset is None:
+        raise Exception('The dataset is None. Cannot implement this kind of interpretation')
+    interpreter = get_normlime_interpreter(img, model, dataset, 
+                                 num_samples=num_samples, batch_size=batch_size,
                                      save_dir=save_dir)
-    else:
-        raise Exception('The {} interpretation method is not supported yet!'.format(algo))
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
     interpreter.interpret(img, save_dir=save_dir)
     
     
-def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
+def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
     def predict_func(image):
         image = image.astype('float32')
         for i in range(image.shape[0]):
@@ -79,8 +119,8 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
         model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
-    if dataset is not None:
-        labels_name = dataset.labels
+    if hasattr(model, 'labels'):
+        labels_name = model.labels
     interpreter = Interpretation('lime', 
                             predict_func,
                             labels_name,

+ 23 - 0
tutorials/interpret/lime.py

@@ -0,0 +1,23 @@
+import os
+# 选择使用0号卡
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import os.path as osp
+import paddlex as pdx
+
+# 下载和解压Imagenet果蔬分类数据集
+veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
+
+# 加载模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
+# 可解释性可视化
+pdx.interpret.lime(
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model,
+          save_dir='./')

+ 1 - 8
tutorials/interpret/interpret.py → tutorials/interpret/normlime.py

@@ -24,15 +24,8 @@ test_dataset = pdx.datasets.ImageNet(
     transforms=model.test_transforms)
 
 # 可解释性可视化
-pdx.interpret.visualize(
-         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
-          model,
-          test_dataset, 
-          algo='lime',
-          save_dir='./')
-pdx.interpret.visualize(
+pdx.interpret.normlime(
          'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
           model, 
           test_dataset, 
-          algo='normlime',
           save_dir='./')