Browse Source

add interpret

sunyanfang01 5 years ago
parent
commit
1ee92504fb
1 changed files with 12 additions and 16 deletions
  1. 12 16
      tutorials/interpret/interpret.py

+ 12 - 16
tutorials/interpret/interpret.py

@@ -10,34 +10,30 @@ from paddlex.cls import transforms
 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
 pdx.utils.download_and_decompress(veg_dataset, path='./')
 
-# 定义测试集的transform
-test_transforms = transforms.Compose([
-    transforms.ResizeByShort(short_size=256),
-    transforms.CenterCrop(crop_size=224),
-    transforms.Normalize()
-])
+# 下载和解压已训练好的MobileNetV2模型
+model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
+pdx.utils.download_and_decompress(model_file, path='./')
+
+# 加载模型
+model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+
 
 # 定义测试所用的数据集
 test_dataset = pdx.datasets.ImageNet(
     data_dir='mini_imagenet_veg',
     file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
     label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
-    transforms=test_transforms)
-
-# 下载和解压已训练好的MobileNetV2模型
-model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
-pdx.utils.download_and_decompress(model_file, path='./')
-
-# 导入模型
-model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
+    transforms=model.test_transforms)
 
 # 可解释性可视化
-pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+pdx.interpret.visualize(
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
           model,
           test_dataset, 
           algo='lime',
           save_dir='./')
-pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
+pdx.interpret.visualize(
+         'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', 
           model, 
           test_dataset, 
           algo='normlime',