Эх сурвалжийг харах

Merge pull request #74 from SunAhong1993/syf0519

intrepret
Jason 5 жил өмнө
parent
commit
8be488a565

+ 2 - 1
paddlex/interpret/visualize.py

@@ -113,7 +113,8 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
     root_path = osp.join(root_path, '.paddlex')
     root_path = osp.join(root_path, '.paddlex')
     pre_models_path = osp.join(root_path, "pre_models")
     pre_models_path = osp.join(root_path, "pre_models")
     if not osp.exists(pre_models_path):
     if not osp.exists(pre_models_path):
-        os.makedirs(root_path)
+        if not osp.exists(root_path):
+            os.makedirs(root_path)
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
         pdx.utils.download_and_decompress(url, path=root_path)
         pdx.utils.download_and_decompress(url, path=root_path)
     npy_dir = precompute_for_normlime(precompute_predict_func, 
     npy_dir = precompute_for_normlime(precompute_predict_func, 

+ 1 - 3
tutorials/interpret/interpret.py

@@ -4,7 +4,6 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
 
 import os.path as osp
 import os.path as osp
 import paddlex as pdx
 import paddlex as pdx
-from paddlex.cls import transforms
 
 
 # 下载和解压Imagenet果蔬分类数据集
 # 下载和解压Imagenet果蔬分类数据集
 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
@@ -17,7 +16,6 @@ pdx.utils.download_and_decompress(model_file, path='./')
 # 加载模型
 # 加载模型
 model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
 model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
 
 
-
 # 定义测试所用的数据集
 # 定义测试所用的数据集
 test_dataset = pdx.datasets.ImageNet(
 test_dataset = pdx.datasets.ImageNet(
     data_dir='mini_imagenet_veg',
     data_dir='mini_imagenet_veg',
@@ -37,4 +35,4 @@ pdx.interpret.visualize(
           model, 
           model, 
           test_dataset, 
           test_dataset, 
           algo='normlime',
           algo='normlime',
-          save_dir='./')
+          save_dir='./')