sunyanfang01 5 rokov pred
rodič
commit
b821586070

+ 3 - 3
docs/apis/visualize.md

@@ -136,10 +136,10 @@ LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME
 
 
 ### 使用示例
-> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。
+> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/lime.py)。
 
 
-## LIME可解释性结果可视化
+## NormLIME可解释性结果可视化
 ```
 paddlex.interpret.normlime(img_file, 
                            model, 
@@ -163,5 +163,5 @@ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会
 
 **注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
 ### 使用示例
-> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。
+> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。
 

+ 23 - 0
tutorials/interpret/lime.py

@@ -0,0 +1,23 @@
+import os
+# 选择使用0号卡
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import os.path as osp
+import paddlex as pdx
+
+# 下载和解压Imagenet果蔬分类数据集
+veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
+
+# 加载模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
+# 可解释性可视化
+pdx.interpret.lime(
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model,
+          save_dir='./')

+ 31 - 0
tutorials/interpret/normlime.py

@@ -0,0 +1,31 @@
+import os
+# 选择使用0号卡
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import os.path as osp
+import paddlex as pdx
+
+# 下载和解压Imagenet果蔬分类数据集
+veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
+
+# 加载模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
+# 定义测试所用的数据集
+test_dataset = pdx.datasets.ImageNet(
+    data_dir='mini_imagenet_veg',
+    file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
+    label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
+    transforms=model.test_transforms)
+
+# 可解释性可视化
+pdx.interpret.normlime(
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model, 
+          test_dataset, 
+          save_dir='./')