sunyanfang01 5 vuotta sitten
vanhempi
commit
d7a1c2ae38
1 muutettua tiedostoa jossa 2 lisäystä ja 2 poistoa
  1. 2 2
      paddlex/interpret/core/normlime_base.py

+ 2 - 2
paddlex/interpret/core/normlime_base.py

@@ -85,7 +85,7 @@ def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_
     precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
 
     # load precomputed results, compute normlime weights and save.
-    fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}.npy'.format(num_samples)))
+    fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
     return compute_normlime_weights(fname_list, save_dir, num_samples)
 
 
@@ -174,6 +174,7 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
 
 def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
     normlime_weights_all_labels = {}
+    
     for f in a_list_lime_fnames:
         try:
             lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
@@ -183,7 +184,6 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
             logging.info('When loading precomputed LIME result, skipping' + str(f))
             continue
         logging.info('Loading precomputed LIME result,' + str(f))
-
         pred_labels = lime_weights.keys()
         for y in pred_labels:
             normlime_weights = normlime_weights_all_labels.get(y, {})