Ver código fonte

Update visualize.py

SunAhong1993 5 anos atrás
pai
commit
5dc2716de0
1 arquivos alterados com 18 adições e 10 exclusões
  1. 18 10
      paddlex/cv/models/explanation/visualize.py

+ 18 - 10
paddlex/cv/models/explanation/visualize.py

@@ -23,7 +23,7 @@ from .core.normlime_base import precompute_normlime_weights
 
 def visualize(img_file, 
               model, 
-              normlime_dataset=None,
+              dataset=None,
               explanation_type='lime',
               num_samples=3000, 
               batch_size=50,
@@ -39,11 +39,11 @@ def visualize(img_file,
     img = np.expand_dims(img, axis=0)
     explaier = None
     if explanation_type == 'lime':
-        explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
+        explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
     elif explanation_type == 'normlime':
-        if normlime_dataset is None:
-            raise Exception('The normlime_dataset is None. Cannot implement this kind of explanation')
-        explaier = get_normlime_explaier(img, model, normlime_dataset, 
+        if dataset is None:
+            raise Exception('The dataset is None. Cannot implement this kind of explanation')
+        explaier = get_normlime_explaier(img, model, dataset, 
                                      num_samples=num_samples, batch_size=batch_size,
                                      save_dir=save_dir)
     else:
@@ -52,7 +52,7 @@ def visualize(img_file,
     explaier.explain(img, save_dir=save_dir)
     
     
-def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
+def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
     def predict_func(image):
         image = image.astype('float32')
         for i in range(image.shape[0]):
@@ -60,14 +60,18 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
         return out[0]
+    labels_name = None
+    if dataset is not None:
+        labels_name = dataset.labels
     explaier = Explanation('lime', 
                             predict_func,
+                            labels_name,
                             num_samples=num_samples, 
                             batch_size=batch_size)
     return explaier
 
 
-def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
+def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
     def precompute_predict_func(image):
         image = image.astype('float32')
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
@@ -80,6 +84,9 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
         return out[0]
+    labels_name = None
+    if dataset is not None:
+        labels_name = dataset.labels
     root_path = os.environ['HOME']
     root_path = osp.join(root_path, '.paddlex')
     pre_models_path = osp.join(root_path, "pre_models")
@@ -88,21 +95,22 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
         # TODO
         # paddlex.utils.download_and_decompress(url, path=pre_models_path)
     npy_dir = precompute_for_normlime(precompute_predict_func, 
-                                      normlime_dataset, 
+                                      dataset, 
                                       num_samples=num_samples, 
                                       batch_size=batch_size,
                                       save_dir=save_dir)
     explaier = Explanation('normlime', 
                             predict_func,
+                            labels_name,
                             num_samples=num_samples, 
                             batch_size=batch_size,
                             normlime_weights=npy_dir)
     return explaier
 
 
-def precompute_for_normlime(predict_func, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
+def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
     image_list = []
-    for item in normlime_dataset.file_list:
+    for item in dataset.file_list:
         image_list.append(item[0])
     return precompute_normlime_weights(
             image_list,