Sfoglia il codice sorgente

Update visualize.py

SunAhong1993 5 anni fa
parent
commit
5a2ad68451
1 ha cambiato i file con 7 aggiunte e 3 eliminazioni
  1. 7 3
      paddlex/cv/models/explanation/visualize.py

+ 7 - 3
paddlex/cv/models/explanation/visualize.py

@@ -57,8 +57,10 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
         image = image.astype('float32')
         for i in range(image.shape[0]):
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
+        model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
     if dataset is not None:
@@ -74,15 +76,19 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
 def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
     def precompute_predict_func(image):
         image = image.astype('float32')
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
+        model.test_transforms.transforms = tmp_transforms
         return out[0]
     def predict_func(image):
         image = image.astype('float32')
         for i in range(image.shape[0]):
             image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
+        tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
+        model.test_transforms.transforms = tmp_transforms
         return out[0]
     labels_name = None
     if dataset is not None:
@@ -118,6 +124,4 @@ def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=
             num_samples=num_samples, 
             batch_size=batch_size,
             save_dir=save_dir)
-
-
-    
+