Browse Source

fix the tutorial

sunyanfang01 5 years ago
parent
commit
8f095942e2
2 changed files with 8 additions and 6 deletions
  1. 4 3
      paddlex/interpret/visualize.py
  2. 4 3
      tutorials/interpret/interpret.py

+ 4 - 3
paddlex/interpret/visualize.py

@@ -17,6 +17,7 @@ import cv2
 import copy
 import os.path as osp
 import numpy as np
+import paddlex as pdx
 from .interpretation_predict import interpretation_predict
 from .core.interpretation import Interpretation
 from .core.normlime_base import precompute_normlime_weights
@@ -35,7 +36,7 @@ def visualize(img_file,
         model (paddlex.cv.models): paddlex中的模型。
         dataset (paddlex.datasets): 数据集读取器,默认为None。
         algo (str): 可解释性方式,当前可选'lime'和'normlime'。
-        num_samples (int): 随机采样数量,默认为3000。
+        num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
         batch_size (int): 预测数据batch大小,默认为50。
         save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。        
     """
@@ -111,8 +112,8 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
     pre_models_path = osp.join(root_path, "pre_models")
     if not osp.exists(pre_models_path):
         os.makedirs(pre_models_path)
-        # TODO
-        # paddlex.utils.download_and_decompress(url, path=pre_models_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=pre_models_path)
     npy_dir = precompute_for_normlime(precompute_predict_func, 
                                       dataset, 
                                       num_samples=num_samples, 

+ 4 - 3
tutorials/interpret/interpret.py

@@ -2,11 +2,12 @@ import os
 # 选择使用0号卡
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
+import os.path as osp
 import paddlex as pdx
-from paddlex.cla import transforms
+from paddlex.cls import transforms
 
 # 下载和解压Imagenet果蔬分类数据集
-veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/mini_imagenet_veg.tar.gz'
+veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
 pdx.utils.download_and_decompress(veg_dataset, path='./')
 
 # 定义测试集的transform
@@ -24,7 +25,7 @@ test_dataset = pdx.datasets.ImageNet(
     transforms=test_transforms)
 
 # 下载和解压已训练好的MobileNetV2模型
-model_file = 'https://bj.bcebos.com/paddlex/models/mini_imagenet_veg_mobilenetv2.tar.gz'
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
 pdx.utils.download_and_decompress(model_file, path='./')
 
 # 导入模型