Преглед изворни кода

Update explanation_algorithms.py

SunAhong1993 пре 5 година
родитељ
комит
d51704bd3c
1 измењених фајлова са 34 додато и 11 уклоњено
  1. 34 11
      paddlex/cv/models/explanation/core/explanation_algorithms.py

+ 34 - 11
paddlex/cv/models/explanation/core/explanation_algorithms.py

@@ -25,7 +25,7 @@ import cv2
 
 
 class CAM(object):
-    def __init__(self, predict_fn):
+    def __init__(self, predict_fn, label_names):
         """
 
         Args:
@@ -37,6 +37,7 @@ class CAM(object):
 
         """
         self.predict_fn = predict_fn
+        self.label_names = label_names
 
     def preparation_cam(self, data_):
         image_show = read_image(data_)
@@ -61,8 +62,13 @@ class CAM(object):
 
         fc_weights = paddle_get_fc_weights()
         feature_maps = result[1]
+        
+        l = pred_label[0]
+        ln = l
+        if self.label_names is not None:
+            ln = self.label_names[l]
 
-        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
+        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
         return feature_maps, fc_weights
 
     def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
@@ -73,6 +79,9 @@ class CAM(object):
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
+            ln = l 
+            if self.label_names is not None:
+                ln = self.label_names[l]
 
             psize = 5
             nrows = 1
@@ -84,7 +93,7 @@ class CAM(object):
                 ax.axis("off")
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
 
             axes[1].imshow(cam)
             axes[1].set_title("CAM")
@@ -100,7 +109,7 @@ class CAM(object):
 
 
 class LIME(object):
-    def __init__(self, predict_fn, num_samples=3000, batch_size=50):
+    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
         """
         LIME wrapper. See lime_base.py for the detailed LIME implementation.
         Args:
@@ -115,6 +124,7 @@ class LIME(object):
         self.labels = None
         self.image = None
         self.lime_explainer = None
+        self.label_names = label_names
 
     def preparation_lime(self, data_):
         image_show = read_image(data_)
@@ -137,8 +147,13 @@ class LIME(object):
         self.predicted_probability = probability[pred_label[0]]
         self.image = image_show[0]
         self.labels = pred_label
-
-        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
+        
+        l = pred_label[0]
+        ln = l
+        if self.label_names is not None:
+            ln = self.label_names[l]
+            
+        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
 
         end = time.time()
         algo = lime_base.LimeImageExplainer()
@@ -155,6 +170,9 @@ class LIME(object):
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
+            ln = l 
+            if self.label_names is not None:
+                ln = self.label_names[l]
 
             psize = 5
             nrows = 2
@@ -167,7 +185,7 @@ class LIME(object):
                 ax.axis("off")
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
 
             axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
             axes[1].set_title("superpixel segmentation")
@@ -179,7 +197,7 @@ class LIME(object):
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
+                axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")
 
         if save_to_disk and save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
@@ -192,7 +210,7 @@ class LIME(object):
 
 
 class NormLIME(object):
-    def __init__(self, predict_fn, num_samples=3000, batch_size=50,
+    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
                  kmeans_model_for_normlime=None, normlime_weights=None):
         if kmeans_model_for_normlime is None:
             try:
@@ -218,6 +236,7 @@ class NormLIME(object):
 
         self.labels = None
         self.image = None
+        self.label_names = label_names
 
     def predict_cluster_labels(self, feature_map, segments):
         return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
@@ -239,6 +258,7 @@ class NormLIME(object):
     def preparation_normlime(self, data_):
         self._lime = LIME(
             self.predict_fn,
+            self.label_names,
             self.num_samples,
             self.batch_size
         )
@@ -273,6 +293,9 @@ class NormLIME(object):
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
+            ln = l
+            if self.label_names is not None:
+                ln = self.label_names[l]
 
             psize = 5
             nrows = 4
@@ -287,7 +310,7 @@ class NormLIME(object):
 
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
+            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
 
             axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
             axes[1].set_title("superpixel segmentation")
@@ -416,4 +439,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
             os.path.join(
                 save_outdir, f_out
             )
-        )
+        )