Browse Source

Merge pull request #71 from SunAhong1993/syf0519

fix the interpret bug
Jason 5 năm trước cách đây
mục cha
commit
645ccef9c9

+ 3 - 0
paddlex/interpret/core/interpretation_algorithms.py

@@ -442,3 +442,6 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
                 save_outdir, f_out
             )
         )
+    print('The image of intrepretation result save in {}'.format(os.path.join(
+                save_outdir, f_out
+            )))

+ 2 - 1
paddlex/interpret/core/lime_base.py

@@ -36,6 +36,7 @@ from skimage.color import gray2rgb
 from sklearn.linear_model import Ridge, lars_path
 from sklearn.utils import check_random_state
 
+import tqdm
 import copy
 from functools import partial
 from skimage.segmentation import quickshift
@@ -509,7 +510,7 @@ class LimeImageInterpreter(object):
         labels = []
         data[0, :] = 1
         imgs = []
-        for row in data:
+        for row in tqdm.tqdm(data):
             temp = copy.deepcopy(image)
             zeros = np.where(row == 0)[0]
             mask = np.zeros(segments.shape).astype(bool)

+ 4 - 2
paddlex/interpret/visualize.py

@@ -44,6 +44,8 @@ def visualize(img_file,
         'Now the interpretation visualize only be supported in classifier!'
     if model.status != 'Normal':
         raise Exception('The interpretation only can deal with the Normal model')
+    if not osp.exists(save_dir):
+        os.makedirs(save_dir)
     model.arrange_transforms(
                 transforms=model.test_transforms, mode='test')
     tmp_transforms = copy.deepcopy(model.test_transforms)
@@ -111,9 +113,9 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
     root_path = osp.join(root_path, '.paddlex')
     pre_models_path = osp.join(root_path, "pre_models")
     if not osp.exists(pre_models_path):
-        os.makedirs(pre_models_path)
+        os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
-        pdx.utils.download_and_decompress(url, path=pre_models_path)
+        pdx.utils.download_and_decompress(url, path=root_path)
     npy_dir = precompute_for_normlime(precompute_predict_func, 
                                       dataset, 
                                       num_samples=num_samples, 

+ 12 - 13
tutorials/interpret/interpret.py

@@ -4,6 +4,7 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
 import os.path as osp
 import paddlex as pdx
+from paddlex.cls import transforms
 
 # 下载和解压Imagenet果蔬分类数据集
 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
@@ -16,6 +17,7 @@ pdx.utils.download_and_decompress(model_file, path='./')
 # 加载模型
 model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
 
+
 # 定义测试所用的数据集
 test_dataset = pdx.datasets.ImageNet(
     data_dir='mini_imagenet_veg',
@@ -24,18 +26,15 @@ test_dataset = pdx.datasets.ImageNet(
     transforms=model.test_transforms)
 
 # 可解释性可视化
-# LIME算法
 pdx.interpret.visualize(
-    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
-    model,
-    test_dataset,
-    algo='lime',
-    save_dir='./')
-
-# NormLIME算法
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model,
+          test_dataset, 
+          algo='lime',
+          save_dir='./')
 pdx.interpret.visualize(
-    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
-    model,
-    test_dataset,
-    algo='normlime',
-    save_dir='./')
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+          model, 
+          test_dataset, 
+          algo='normlime',
+          save_dir='./')