소스 검색

modify save file name

seven 5 년 전
부모
커밋
43046bfa6d
2개의 변경된 파일10개의 추가작업 그리고 29개의 파일을 삭제
  1. 8 27
      paddlex/interpret/core/interpretation_algorithms.py
  2. 2 2
      paddlex/interpret/visualize.py

+ 8 - 27
paddlex/interpret/core/interpretation_algorithms.py

@@ -107,7 +107,6 @@ class CAM(object):
             axes[1].set_title("CAM")
             axes[1].set_title("CAM")
 
 
         if save_outdir is not None:
         if save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'cam')
             save_fig(data_, save_outdir, 'cam')
 
 
         if visualization:
         if visualization:
@@ -219,7 +218,7 @@ class LIME(object):
                     self.lime_interpreter, l, w)
                     self.lime_interpreter, l, w)
                 temp, mask = self.lime_interpreter.get_image_and_mask(
                 temp, mask = self.lime_interpreter.get_image_and_mask(
                     l,
                     l,
-                    positive_only=False,
+                    positive_only=True,
                     hide_rest=False,
                     hide_rest=False,
                     num_features=num_to_show)
                     num_features=num_to_show)
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
@@ -227,7 +226,6 @@ class LIME(object):
                     "label {}, first {} superpixels".format(ln, num_to_show))
                     "label {}, first {} superpixels".format(ln, num_to_show))
 
 
         if save_outdir is not None:
         if save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'lime', self.num_samples)
             save_fig(data_, save_outdir, 'lime', self.num_samples)
 
 
         if visualization:
         if visualization:
@@ -412,7 +410,6 @@ class NormLIMEStandard(object):
             self._lime.lime_interpreter.local_weights = lime_weights
             self._lime.lime_interpreter.local_weights = lime_weights
 
 
         if save_outdir is not None:
         if save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
 
 
         if visualization:
         if visualization:
@@ -596,7 +593,6 @@ class NormLIME(object):
             self._lime.lime_interpreter.local_weights = lime_weights
             self._lime.lime_interpreter.local_weights = lime_weights
 
 
         if save_outdir is not None:
         if save_outdir is not None:
-            os.makedirs(save_outdir, exist_ok=True)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
             save_fig(data_, save_outdir, 'normlime', self.num_samples)
 
 
         if visualization:
         if visualization:
@@ -674,26 +670,11 @@ def get_cam(image_show,
 
 
 def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
 def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
     import matplotlib.pyplot as plt
     import matplotlib.pyplot as plt
-    if isinstance(data_, str):
-        if algorithm_name == 'cam':
-            f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
-        else:
-            f_out = "{}_{}_s{}.png".format(algorithm_name,
-                                           data_.split('/')[-1], num_samples)
-        plt.savefig(os.path.join(save_outdir, f_out))
+    if algorithm_name == 'cam':
+        f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
     else:
     else:
-        n = 0
-        if algorithm_name == 'cam':
-            f_out = 'cam-{}.png'.format(n)
-        else:
-            f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
-        while os.path.exists(os.path.join(save_outdir, f_out)):
-            n += 1
-            if algorithm_name == 'cam':
-                f_out = 'cam-{}.png'.format(n)
-            else:
-                f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
-            continue
-        plt.savefig(os.path.join(save_outdir, f_out))
-    logging.info('The image of intrepretation result save in {}'.format(
-        os.path.join(save_outdir, f_out)))
+        f_out = "{}_{}_s{}.png".format(save_outdir, algorithm_name,
+                                       num_samples)
+
+    plt.savefig(f_out)
+    logging.info('The image of intrepretation result save in {}'.format(f_out))

+ 2 - 2
paddlex/interpret/visualize.py

@@ -58,7 +58,7 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
     interpreter = get_lime_interpreter(
     interpreter = get_lime_interpreter(
         img, model, num_samples=num_samples, batch_size=batch_size)
         img, model, num_samples=num_samples, batch_size=batch_size)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    interpreter.interpret(img, save_dir=save_dir)
+    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
 
 
 
 
 def normlime(img_file,
 def normlime(img_file,
@@ -111,7 +111,7 @@ def normlime(img_file,
         save_dir=save_dir,
         save_dir=save_dir,
         normlime_weights_file=normlime_weights_file)
         normlime_weights_file=normlime_weights_file)
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
-    interpreter.interpret(img, save_dir=save_dir)
+    interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
 
 
 
 
 def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
 def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):