
fix the interpret for python35

sunyanfang01 5 years ago
parent
commit
4fc5f47070
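For context on the change: f-string literals (f"...") are a Python 3.6 feature, so a module containing them fails to parse at all under Python 3.5. The diffs below therefore rewrite every f-string with str.format or %-formatting, and route messages through paddlex.utils.logging instead of bare print. A minimal sketch of the substituted pattern, standard library only, with placeholder values:

# Python 3.6+ only (SyntaxError on 3.5):
#     f"predicted result: {ln} with probability {prob:.3f}"
# Python 3.5-compatible equivalent used throughout this commit:
ln, prob = "cat", 0.98765                      # placeholder label / probability
prob_str = "%.3f" % prob                       # "0.988"
msg = "predicted result: {} with probability {}.".format(ln, prob_str)
print(msg)  # the patched code calls paddlex.utils.logging.info(msg) instead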

+ 4 - 3
paddlex/interpret/as_data_reader/readers.py

@@ -20,6 +20,7 @@ import six
 import glob
 from .data_path_utils import _find_classes
 from PIL import Image
+import paddlex.utils.logging as logging
 
 
 def resize_short(img, target_size, interpolation=None):
@@ -117,7 +118,7 @@ def read_image(img_path, target_size=256, crop_size=224):
         assert len(img_path.shape) == 4
         return img_path
     else:
-        ValueError(f"Not recognized data type {type(img_path)}.")
+        raise ValueError("Not recognized data type {}.".format(type(img_path)))
 
 
 class ReaderConfig(object):
@@ -156,7 +157,7 @@ class ReaderConfig(object):
 
                 img = cv2.imread(img_path)
                 if img is None:
-                    print(img_path)
+                    logging.info(img_path)
                     continue
                 img = resize_short(img, target_size, interpolation=None)
                 img = crop_image(img, crop_size, center=self.is_test)
@@ -208,7 +209,7 @@ def create_reader(list_image_path, list_label=None, is_test=False):
 
             img = cv2.imread(img_path)
             if img is None:
-                print(img_path)
+                logging.info(img_path)
                 continue
 
             img = resize_short(img, target_size, interpolation=None)
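
For reference, read_image accepts either an image path or an already-batched NHWC ndarray; anything else is an error. A hedged standalone sketch of that dispatch with the Python 3.5-compatible message (check_image_input is a hypothetical name used only for illustration):

import numpy as np

def check_image_input(img_path):
    # Accept a file path (to be loaded and preprocessed by the caller)
    # or a pre-batched (N, H, W, C) array; reject anything else.
    if isinstance(img_path, str):
        return img_path
    elif isinstance(img_path, np.ndarray):
        assert len(img_path.shape) == 4, "expected a 4-D (N, H, W, C) batch"
        return img_path
    else:
        raise ValueError("Not recognized data type {}.".format(type(img_path)))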

+ 28 - 22
paddlex/interpret/core/interpretation_algorithms.py

@@ -21,6 +21,7 @@ from . import lime_base
 from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, gen_user_home
 from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
 from paddlex.interpret.as_data_reader.readers import read_image
+import paddlex.utils.logging as logging
 
 
 import cv2
@@ -71,7 +72,8 @@ class CAM(object):
         if self.label_names is not None:
             ln = self.label_names[l]
 
-        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
+        prob_str = "%.3f" % (probability[pred_label[0]])
+        logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
         return feature_maps, fc_weights
 
     def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
@@ -96,7 +98,8 @@ class CAM(object):
                 ax.axis("off")
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+            prob_str = "%.3f" % (self.predicted_probability)
+            axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
 
             axes[1].imshow(cam)
             axes[1].set_title("CAM")
@@ -157,14 +160,15 @@ class LIME(object):
         if self.label_names is not None:
             ln = self.label_names[l]
             
-        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
+        prob_str = "%.3f" % (probability[pred_label[0]])
+        logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
 
         end = time.time()
         algo = lime_base.LimeImageInterpreter()
         interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
                                               num_samples=self.num_samples, batch_size=self.batch_size)
         self.lime_interpreter = interpreter
-        print('lime time: ', time.time() - end, 's.')
+        logging.info('lime time: ' + str(time.time() - end) + 's.')
 
     def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
         if self.lime_interpreter is None:
@@ -189,7 +193,8 @@ class LIME(object):
                 ax.axis("off")
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+            prob_str = "%.3f" % (self.predicted_probability)
+            axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
 
             axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
@@ -201,7 +206,7 @@ class LIME(object):
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")
+                axes[ncols + i].set_title("label {}, first {} superpixels".format(ln, num_to_show))
 
         if save_to_disk and save_outdir is not None:
             os.makedirs(save_outdir, exist_ok=True)
@@ -232,8 +237,9 @@ class NormLIME(object):
                 raise ValueError("NormLIME needs the KMeans model, where we provided a default one in "
                                  "pre_models/kmeans_model.pkl.")
         else:
-            print("Warning: It is *strongly* suggested to use the default KMeans model in pre_models/kmeans_model.pkl. "
-                  "Use another one will change the final result.")
+            logging.debug("Warning: It is *strongly* suggested to use the "
+                          "default KMeans model in pre_models/kmeans_model.pkl. "
+                          "Using another one will change the final result.")
             self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
 
         self.num_samples = num_samples
@@ -243,7 +249,7 @@ class NormLIME(object):
             self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
         except:
             self.normlime_weights = None
-            print("Warning: not find the correct precomputed Normlime result.")
+            logging.debug("Warning: could not find the correct precomputed NormLIME result.")
 
         self.predict_fn = predict_fn
 
@@ -289,8 +295,7 @@ class NormLIME(object):
         self.predicted_probability = self._lime.predicted_probability
         self.image = image_show[0]
         self.labels = self._lime.labels
-        # print(f'predicted result: {self.predicted_label} with probability {self.predicted_probability: .3f}')
-        print('performing NormLIME operations ...')
+        logging.info('performing NormLIME operations ...')
 
         cluster_labels = self.predict_cluster_labels(
             compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
@@ -329,7 +334,8 @@ class NormLIME(object):
 
             axes = axes.ravel()
             axes[0].imshow(self.image)
-            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
+            prob_str = "%.3f" % (self.predicted_probability)
+            axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
 
             axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
             axes[1].set_title("superpixel segmentation")
@@ -342,7 +348,7 @@ class NormLIME(object):
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
+                axes[ncols + i].set_title("LIME: first {} superpixels".format(num_to_show))
 
             # NormLIME visualization
             self._lime.lime_interpreter.local_weights = g_weights
@@ -351,7 +357,7 @@ class NormLIME(object):
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 2 + i].set_title(f"NormLIME: first {num_to_show} superpixels")
+                axes[ncols * 2 + i].set_title("NormLIME: first {} superpixels".format(num_to_show))
 
             # NormLIME*LIME visualization
             combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
@@ -361,7 +367,7 @@ class NormLIME(object):
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
+                axes[ncols * 3 + i].set_title("Combined: first {} superpixels".format(num_to_show))
 
             self._lime.lime_interpreter.local_weights = lime_weights
 
@@ -433,32 +439,32 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
     import matplotlib.pyplot as plt
     if isinstance(data_, str):
         if algorithm_name == 'cam':
-            f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
+            f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
         else:
-            f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
+            f_out = "{}_{}_s{}.png".format(algorithm_name, data_.split('/')[-1], num_samples)
         plt.savefig(
             os.path.join(save_outdir, f_out)
         )
     else:
         n = 0
         if algorithm_name == 'cam':
-            f_out = f'cam-{n}.png'
+            f_out = 'cam-{}.png'.format(n)
         else:
-            f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+            f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
         while os.path.exists(
                 os.path.join(save_outdir, f_out)
         ):
             n += 1
             if algorithm_name == 'cam':
-                f_out = f'cam-{n}.png'
+                f_out = 'cam-{}.png'.format(n)
             else:
-                f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
+                f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
             continue
         plt.savefig(
             os.path.join(
                 save_outdir, f_out
             )
         )
-    print('The image of intrepretation result save in {}'.format(os.path.join(
+    logging.info('The interpretation result image is saved in {}'.format(os.path.join(
                 save_outdir, f_out
             )))
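
save_fig above probes cam-0.png, cam-1.png, ... (or <algorithm>_s<num_samples>-<n>.png) until it finds a name that does not exist yet, so repeated runs do not overwrite earlier figures. A small sketch of that naming loop pulled out as a helper (unique_fig_name is a hypothetical name):

import os

def unique_fig_name(save_outdir, algorithm_name, num_samples=3000):
    # Increment n until the candidate file name is unused in save_outdir.
    n = 0
    while True:
        if algorithm_name == 'cam':
            f_out = 'cam-{}.png'.format(n)
        else:
            f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
        if not os.path.exists(os.path.join(save_outdir, f_out)):
            return f_out
        n += 1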

+ 4 - 4
paddlex/interpret/core/lime_base.py

@@ -34,6 +34,7 @@ import scipy as sp
 import tqdm
 import copy
 from functools import partial
+import paddlex.utils.logging as logging
 
 
 class LimeBase(object):
@@ -230,9 +231,9 @@ class LimeBase(object):
         local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
 
         if self.verbose:
-            print('Intercept', easy_model.intercept_)
-            print('Prediction_local', local_pred,)
-            print('Right:', neighborhood_labels[0, label])
+            logging.info('Intercept ' + str(easy_model.intercept_))
+            logging.info('Prediction_local ' + str(local_pred))
+            logging.info('Right: ' + str(neighborhood_labels[0, label]))
         return (easy_model.intercept_,
                 sorted(zip(used_features, easy_model.coef_),
                        key=lambda x: np.abs(x[1]), reverse=True),
@@ -451,7 +452,6 @@ class LimeImageInterpreter(object):
             d = cdist(centroids, centroids, 'sqeuclidean')
 
             for x in np.unique(segments):
-                # print(np.argmin(d[x]))
                 a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
                 mx = np.mean(np.concatenate(a), axis=0)
                 fudged_image[segments == x] = mx
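
The last hunk above builds the "fudged" image LIME substitutes when it hides segments: every superpixel is replaced by the mean colour of its five nearest superpixels, nearest by squared distance between segment centroids. A hedged sketch of that step, assuming segment labels run 0..n-1 and index the rows of centroids:

import numpy as np
from scipy.spatial.distance import cdist

def fudge_with_neighbour_means(image, segments, centroids):
    # image: (H, W, C) array, segments: (H, W) integer labels 0..n-1,
    # centroids: (n, 2) centroid coordinates, one row per segment label.
    fudged_image = image.copy()
    d = cdist(centroids, centroids, 'sqeuclidean')
    for x in np.unique(segments):
        neighbours = np.argsort(d[x])[1:6]   # 5 closest segments, skipping x itself
        pixels = [image[segments == i] for i in neighbours]
        mx = np.mean(np.concatenate(pixels), axis=0)
        fudged_image[segments == x] = mx
    return fudged_image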

+ 14 - 15
paddlex/interpret/core/normlime_base.py

@@ -67,7 +67,6 @@ def centroid_using_superpixels(features, segments):
     one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
     for i, r in enumerate(regions):
         one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
-    # print(one_list.shape)
     return one_list
 
 
@@ -85,7 +84,7 @@ def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_
     precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
 
     # load precomputed results, compute normlime weights and save.
-    fname_list = glob.glob(os.path.join(save_dir, f'lime_weights_s{num_samples}*.npy'))
+    fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
     return compute_normlime_weights(fname_list, save_dir, num_samples)
 
 
@@ -117,10 +116,10 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
 
     for data_index, each_data_ in enumerate(list_data_):
         if isinstance(each_data_, str):
-            save_path = f"lime_weights_s{num_samples}_{each_data_.split('/')[-1].split('.')[0]}.npy"
+            save_path = "lime_weights_s{}_{}.npy".format(num_samples, each_data_.split('/')[-1].split('.')[0])
             save_path = os.path.join(save_dir, save_path)
         else:
-            save_path = f"lime_weights_s{num_samples}_{data_index}.npy"
+            save_path = "lime_weights_s{}_{}.npy".format(num_samples, data_index)
             save_path = os.path.join(save_dir, save_path)
 
         if os.path.exists(save_path):
@@ -180,9 +179,9 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
             lime_weights = lime_weights_and_cluster['lime_weights']
             cluster = lime_weights_and_cluster['cluster']
         except:
-            print('When loading precomputed LIME result, skipping', f)
+            logging.info('When loading precomputed LIME result, skipping ' + str(f))
             continue
-        print('Loading precomputed LIME result,', f)
+        logging.info('Loading precomputed LIME result, ' + str(f))
 
         pred_labels = lime_weights.keys()
         for y in pred_labels:
@@ -207,23 +206,23 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
 
     # check normlime
     if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1:
-        print(
-            "\n"
-            "Warning: !!! \n"
-            f"There are at least {max(normlime_weights_all_labels.keys()) + 1} classes, "
-            f"but the NormLIME has results of only {len(normlime_weights_all_labels.keys())} classes. \n"
-            "It may have cause unstable results in the later computation"
-            " but can be improved by computing more test samples."
+        logging.info(
+            "\n" + \
+            "Warning: !!! \n" + \
+            "There are at least {} classes, ".format(max(normlime_weights_all_labels.keys()) + 1) + \
+            "but the NormLIME has results of only {} classes. \n".format(len(normlime_weights_all_labels.keys())) + \
+            "This may cause unstable results in the later computation," + \
+            " which can be improved by computing more test samples." + \
             "\n"
         )
 
     n = 0
-    f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
     while os.path.exists(
             os.path.join(save_dir, f_out)
     ):
         n += 1
-        f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n)
         continue
 
     np.save(
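
precompute_lime_weights writes one lime_weights_s<num_samples>_<image>.npy file per input image, and precompute_normlime_weights later gathers them with a glob; the '*' wildcard before .npy is what lets those per-image suffixes match. A small self-contained sketch of that round trip (the directory and file names are made up for illustration):

import os
import glob
import numpy as np

save_dir = 'normlime_tmp'      # hypothetical scratch directory
num_samples = 3000
os.makedirs(save_dir, exist_ok=True)

# one precomputed result per image, as precompute_lime_weights does
for name in ('img_001', 'img_002'):
    save_path = os.path.join(save_dir, 'lime_weights_s{}_{}.npy'.format(num_samples, name))
    np.save(save_path, {'lime_weights': {}, 'cluster': 0})

# gather them all again; without the '*' only 'lime_weights_s3000.npy' would match
fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
print(sorted(fname_list))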