seven 5 роки тому
батько
коміт
a7aa87a6ae

+ 2 - 8
paddlex/interpret/core/interpretation.py

@@ -33,21 +33,15 @@ class Interpretation(object):
         self.algorithm = supported_algorithms[self.algorithm_name](
             self.predict_fn, label_names, **kwargs)
 
-    def interpret(self,
-                  data_,
-                  visualization=True,
-                  save_to_disk=True,
-                  save_dir='./tmp'):
+    def interpret(self, data_, visualization=True, save_dir='./'):
         """
 
         Args:
             data_: data_ can be a path or numpy.ndarray.
             visualization: whether to show using matplotlib.
-            save_to_disk: whether to save the figure in local disk.
             save_dir: dir to save figure if save_to_disk is True.
 
         Returns:
 
         """
-        return self.algorithm.interpret(data_, visualization, save_to_disk,
-                                        save_dir)
+        return self.algorithm.interpret(data_, visualization, save_dir)

+ 12 - 28
paddlex/interpret/core/interpretation_algorithms.py

@@ -76,16 +76,12 @@ class CAM(object):
             ln, prob_str))
         return feature_maps, fc_weights
 
-    def interpret(self,
-                  data_,
-                  visualization=True,
-                  save_to_disk=True,
-                  save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         feature_maps, fc_weights = self.preparation_cam(data_)
         cam = get_cam(self.image, feature_maps, fc_weights,
                       self.predicted_label)
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
@@ -110,7 +106,7 @@ class CAM(object):
             axes[1].imshow(cam)
             axes[1].set_title("CAM")
 
-        if save_to_disk and save_outdir is not None:
+        if save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'cam')
 
@@ -186,15 +182,11 @@ class LIME(object):
         self.lime_interpreter = interpreter
         logging.info('lime time: ' + str(time.time() - end) + 's.')
 
-    def interpret(self,
-                  data_,
-                  visualization=True,
-                  save_to_disk=True,
-                  save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         if self.lime_interpreter is None:
             self.preparation_lime(data_)
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
@@ -234,7 +226,7 @@ class LIME(object):
                 axes[ncols + i].set_title(
                     "label {}, first {} superpixels".format(ln, num_to_show))
 
-        if save_to_disk and save_outdir is not None:
+        if save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'lime', self.num_samples)
 
@@ -337,11 +329,7 @@ class NormLIMEStandard(object):
 
         return g_weights
 
-    def interpret(self,
-                  data_,
-                  visualization=True,
-                  save_to_disk=True,
-                  save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         if self.normlime_weights is None:
             raise ValueError(
                 "Not find the correct precomputed NormLIME result. \n"
@@ -351,7 +339,7 @@ class NormLIMEStandard(object):
         g_weights = self.preparation_normlime(data_)
         lime_weights = self._lime.lime_interpreter.local_weights
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
@@ -423,7 +411,7 @@ class NormLIMEStandard(object):
 
             self._lime.lime_interpreter.local_weights = lime_weights
 
-        if save_to_disk and save_outdir is not None:
+        if save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
 
@@ -524,11 +512,7 @@ class NormLIME(object):
 
         return g_weights
 
-    def interpret(self,
-                  data_,
-                  visualization=True,
-                  save_to_disk=True,
-                  save_outdir=None):
+    def interpret(self, data_, visualization=True, save_outdir=None):
         if self.normlime_weights is None:
             raise ValueError(
                 "Not find the correct precomputed NormLIME result. \n"
@@ -538,7 +522,7 @@ class NormLIME(object):
         g_weights = self.preparation_normlime(data_)
         lime_weights = self._lime.lime_interpreter.local_weights
 
-        if visualization or save_to_disk:
+        if visualization or save_outdir is not None:
             import matplotlib.pyplot as plt
             from skimage.segmentation import mark_boundaries
             l = self.labels[0]
@@ -611,7 +595,7 @@ class NormLIME(object):
 
             self._lime.lime_interpreter.local_weights = lime_weights
 
-        if save_to_disk and save_outdir is not None:
+        if save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
 

+ 1 - 1
paddlex/interpret/visualize.py

@@ -168,7 +168,7 @@ def get_normlime_interpreter(img,
         normlime_weights_file = precompute_global_classifier(
             dataset,
             predict_func,
-            save_path=normlime_weights_file,
+            save_path=osp.join(save_dir, normlime_weights_file),
             batch_size=batch_size)
 
     interpreter = Interpretation(