瀏覽代碼

fix the interpret

sunyanfang01 5 年之前
父節點
當前提交
2f92c61b7e

+ 3 - 0
paddlex/interpret/core/interpretation_algorithms.py

@@ -442,3 +442,6 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
                 save_outdir, f_out
             )
         )
+    print('The image of intrepretation result save in {}'.format(os.path.join(
+                save_outdir, f_out
+            )))

+ 2 - 1
paddlex/interpret/core/lime_base.py

@@ -36,6 +36,7 @@ from skimage.color import gray2rgb
 from sklearn.linear_model import Ridge, lars_path
 from sklearn.utils import check_random_state
 
+import tqdm
 import copy
 from functools import partial
 from skimage.segmentation import quickshift
@@ -509,7 +510,7 @@ class LimeImageInterpreter(object):
         labels = []
         data[0, :] = 1
         imgs = []
-        for row in data:
+        for row in tqdm.tqdm(data):
             temp = copy.deepcopy(image)
             zeros = np.where(row == 0)[0]
             mask = np.zeros(segments.shape).astype(bool)

+ 5 - 3
paddlex/interpret/visualize.py

@@ -44,6 +44,8 @@ def visualize(img_file,
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
         raise Exception('The interpretation only can deal with the Normal model')
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
     model.arrange_transforms(
                 transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
@@ -108,12 +110,12 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
     if dataset is not None:
         labels_name = dataset.labels
     root_path = os.environ['HOME']
-    root_path = osp.join(root_path, '.paddlex')
+    root_path = osp.join(root_path, '.paddlex0')
     pre_models_path = osp.join(root_path, "pre_models")
     if not osp.exists(pre_models_path):
-        os.makedirs(pre_models_path)
+        os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
-        pdx.utils.download_and_decompress(url, path=pre_models_path)
+        pdx.utils.download_and_decompress(url, path=root_path)
     npy_dir = precompute_for_normlime(precompute_predict_func, 
                                       dataset, 
                                       num_samples=num_samples, 

+ 25 - 22
tutorials/interpret/interpret.py

@@ -4,38 +4,41 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
 import os.path as osp
 import paddlex as pdx
+from paddlex.cls import transforms
 
 # 下载和解压Imagenet果蔬分类数据集
 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
 pdx.utils.download_and_decompress(veg_dataset, path='./')
 
-# 下载和解压已训练好的MobileNetV2模型
-model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
-pdx.utils.download_and_decompress(model_file, path='./')
-
-# 加载模型
-model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+# 定义测试集的transform
+test_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224),
+    transforms.Normalize()
+])
 
 # 定义测试所用的数据集
 test_dataset = pdx.datasets.ImageNet(
     data_dir='mini_imagenet_veg',
     file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
     label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
-    transforms=model.test_transforms)
+    transforms=test_transforms)
 
-# 可解释性可视化
-# LIME算法
-pdx.interpret.visualize(
-    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
-    model,
-    test_dataset,
-    algo='lime',
-    save_dir='./')
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
 
-# NormLIME算法
-pdx.interpret.visualize(
-    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
-    model,
-    test_dataset,
-    algo='normlime',
-    save_dir='./')
+# 导入模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
+# 可解释性可视化
+pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model,
+          test_dataset, 
+          algo='lime',
+          save_dir='./')
+pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model, 
+          test_dataset, 
+          algo='normlime',
+          save_dir='./')