|
@@ -18,7 +18,8 @@ import time
|
|
|
|
|
|
|
|
from . import lime_base
|
|
from . import lime_base
|
|
|
from ..as_data_reader.readers import read_image
|
|
from ..as_data_reader.readers import read_image
|
|
|
-from ._session_preparation import paddle_get_fc_weights
|
|
|
|
|
|
|
+from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, h_pre_models_kmeans
|
|
|
|
|
+from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
|
|
|
|
|
|
|
|
import cv2
|
|
import cv2
|
|
|
|
|
|
|
@@ -37,8 +38,8 @@ class CAM(object):
|
|
|
"""
|
|
"""
|
|
|
self.predict_fn = predict_fn
|
|
self.predict_fn = predict_fn
|
|
|
|
|
|
|
|
- def preparation_cam(self, data_path):
|
|
|
|
|
- image_show = read_image(data_path)
|
|
|
|
|
|
|
+ def preparation_cam(self, data_):
|
|
|
|
|
+ image_show = read_image(data_)
|
|
|
result = self.predict_fn(image_show)
|
|
result = self.predict_fn(image_show)
|
|
|
|
|
|
|
|
logit = result[0][0]
|
|
logit = result[0][0]
|
|
@@ -61,7 +62,7 @@ class CAM(object):
|
|
|
fc_weights = paddle_get_fc_weights()
|
|
fc_weights = paddle_get_fc_weights()
|
|
|
feature_maps = result[1]
|
|
feature_maps = result[1]
|
|
|
|
|
|
|
|
- print('predicted result: ', pred_label[0], probability[pred_label[0]])
|
|
|
|
|
|
|
+ print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
|
|
|
return feature_maps, fc_weights
|
|
return feature_maps, fc_weights
|
|
|
|
|
|
|
|
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
|
|
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
|
|
@@ -115,8 +116,8 @@ class LIME(object):
|
|
|
self.image = None
|
|
self.image = None
|
|
|
self.lime_explainer = None
|
|
self.lime_explainer = None
|
|
|
|
|
|
|
|
- def preparation_lime(self, data_path):
|
|
|
|
|
- image_show = read_image(data_path)
|
|
|
|
|
|
|
+ def preparation_lime(self, data_):
|
|
|
|
|
+ image_show = read_image(data_)
|
|
|
result = self.predict_fn(image_show)
|
|
result = self.predict_fn(image_show)
|
|
|
|
|
|
|
|
result = result[0] # only one image here.
|
|
result = result[0] # only one image here.
|
|
@@ -137,7 +138,7 @@ class LIME(object):
|
|
|
self.image = image_show[0]
|
|
self.image = image_show[0]
|
|
|
self.labels = pred_label
|
|
self.labels = pred_label
|
|
|
|
|
|
|
|
- print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}')
|
|
|
|
|
|
|
+ print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
|
|
|
|
|
|
|
|
end = time.time()
|
|
end = time.time()
|
|
|
algo = lime_base.LimeImageExplainer()
|
|
algo = lime_base.LimeImageExplainer()
|
|
@@ -157,7 +158,7 @@ class LIME(object):
|
|
|
|
|
|
|
|
psize = 5
|
|
psize = 5
|
|
|
nrows = 2
|
|
nrows = 2
|
|
|
- weights_choices = [0.6, 0.75, 0.85]
|
|
|
|
|
|
|
+ weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
|
|
|
ncols = len(weights_choices)
|
|
ncols = len(weights_choices)
|
|
|
|
|
|
|
|
plt.close()
|
|
plt.close()
|
|
@@ -193,15 +194,25 @@ class LIME(object):
|
|
|
class NormLIME(object):
|
|
class NormLIME(object):
|
|
|
def __init__(self, predict_fn, num_samples=3000, batch_size=50,
|
|
def __init__(self, predict_fn, num_samples=3000, batch_size=50,
|
|
|
kmeans_model_for_normlime=None, normlime_weights=None):
|
|
kmeans_model_for_normlime=None, normlime_weights=None):
|
|
|
- assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model."
|
|
|
|
|
- if normlime_weights is None:
|
|
|
|
|
- raise NotImplementedError("Computing NormLIME weights is not implemented yet.")
|
|
|
|
|
|
|
+ if kmeans_model_for_normlime is None:
|
|
|
|
|
+ try:
|
|
|
|
|
+ self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
|
|
|
|
|
+ except:
|
|
|
|
|
+ raise ValueError("NormLIME needs the KMeans model, where we provided a default one in "
|
|
|
|
|
+ "pre_models/kmeans_model.pkl.")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print("Warning: It is *strongly* suggested to use the default KMeans model in pre_models/kmeans_model.pkl. "
|
|
|
|
|
+ "Use another one will change the final result.")
|
|
|
|
|
+ self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
|
|
|
|
|
|
|
|
self.num_samples = num_samples
|
|
self.num_samples = num_samples
|
|
|
self.batch_size = batch_size
|
|
self.batch_size = batch_size
|
|
|
|
|
|
|
|
- self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
|
|
|
|
|
- self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
|
|
|
|
|
+ except:
|
|
|
|
|
+ self.normlime_weights = None
|
|
|
|
|
+ print("Warning: not find the correct precomputed Normlime result.")
|
|
|
|
|
|
|
|
self.predict_fn = predict_fn
|
|
self.predict_fn = predict_fn
|
|
|
|
|
|
|
@@ -215,9 +226,8 @@ class NormLIME(object):
|
|
|
# global weights
|
|
# global weights
|
|
|
g_weights = {y: [] for y in pred_labels}
|
|
g_weights = {y: [] for y in pred_labels}
|
|
|
for y in pred_labels:
|
|
for y in pred_labels:
|
|
|
- cluster_weights_y = self.normlime_weights[y]
|
|
|
|
|
|
|
+ cluster_weights_y = self.normlime_weights.get(y, {})
|
|
|
g_weights[y] = [
|
|
g_weights[y] = [
|
|
|
- # some are not in the dict, 3000 samples may be not enough.
|
|
|
|
|
(i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
|
|
(i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
|
|
|
]
|
|
]
|
|
|
|
|
|
|
@@ -226,38 +236,25 @@ class NormLIME(object):
|
|
|
|
|
|
|
|
return g_weights
|
|
return g_weights
|
|
|
|
|
|
|
|
- def preparation_normlime(self, data_path):
|
|
|
|
|
|
|
+ def preparation_normlime(self, data_):
|
|
|
self._lime = LIME(
|
|
self._lime = LIME(
|
|
|
- lambda images: self.predict_fn(images)[0],
|
|
|
|
|
|
|
+ self.predict_fn,
|
|
|
self.num_samples,
|
|
self.num_samples,
|
|
|
self.batch_size
|
|
self.batch_size
|
|
|
)
|
|
)
|
|
|
- self._lime.preparation_lime(data_path)
|
|
|
|
|
|
|
+ self._lime.preparation_lime(data_)
|
|
|
|
|
|
|
|
- image_show = read_image(data_path)
|
|
|
|
|
- result = self.predict_fn(image_show)
|
|
|
|
|
|
|
+ image_show = read_image(data_)
|
|
|
|
|
|
|
|
- logit = result[0][0] # only one image here.
|
|
|
|
|
- if abs(np.sum(logit) - 1.0) > 1e-4:
|
|
|
|
|
- # softmax
|
|
|
|
|
- exp_result = np.exp(logit)
|
|
|
|
|
- probability = exp_result / np.sum(exp_result)
|
|
|
|
|
- else:
|
|
|
|
|
- probability = logit
|
|
|
|
|
-
|
|
|
|
|
- # only explain top 1
|
|
|
|
|
- pred_label = np.argsort(probability)
|
|
|
|
|
- pred_label = pred_label[-1:]
|
|
|
|
|
-
|
|
|
|
|
- self.predicted_label = pred_label[0]
|
|
|
|
|
- self.predicted_probability = probability[pred_label[0]]
|
|
|
|
|
|
|
+ self.predicted_label = self._lime.predicted_label
|
|
|
|
|
+ self.predicted_probability = self._lime.predicted_probability
|
|
|
self.image = image_show[0]
|
|
self.image = image_show[0]
|
|
|
- self.labels = pred_label
|
|
|
|
|
- print('predicted result: ', pred_label[0], probability[pred_label[0]])
|
|
|
|
|
|
|
+ self.labels = self._lime.labels
|
|
|
|
|
+ # print(f'predicted result: {self.predicted_label} with probability {self.predicted_probability: .3f}')
|
|
|
|
|
+ print('performing NormLIME operations ...')
|
|
|
|
|
|
|
|
- local_feature_map = result[1][0]
|
|
|
|
|
cluster_labels = self.predict_cluster_labels(
|
|
cluster_labels = self.predict_cluster_labels(
|
|
|
- local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments
|
|
|
|
|
|
|
+ compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_explainer.segments
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
|
|
g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
|
|
@@ -265,6 +262,10 @@ class NormLIME(object):
|
|
|
return g_weights
|
|
return g_weights
|
|
|
|
|
|
|
|
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
|
|
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
|
|
|
|
|
+ if self.normlime_weights is None:
|
|
|
|
|
+ raise ValueError("Not find the correct precomputed NormLIME result. \n"
|
|
|
|
|
+ "\t Try to call compute_normlime_weights() first or load the correct path.")
|
|
|
|
|
+
|
|
|
g_weights = self.preparation_normlime(data_)
|
|
g_weights = self.preparation_normlime(data_)
|
|
|
lime_weights = self._lime.lime_explainer.local_exp
|
|
lime_weights = self._lime.lime_explainer.local_exp
|
|
|
|
|
|
|
@@ -275,7 +276,8 @@ class NormLIME(object):
|
|
|
|
|
|
|
|
psize = 5
|
|
psize = 5
|
|
|
nrows = 4
|
|
nrows = 4
|
|
|
- weights_choices = [0.6, 0.85, 0.99]
|
|
|
|
|
|
|
+ weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
|
|
|
|
|
+ nums_to_show = []
|
|
|
ncols = len(weights_choices)
|
|
ncols = len(weights_choices)
|
|
|
|
|
|
|
|
plt.close()
|
|
plt.close()
|
|
@@ -293,32 +295,31 @@ class NormLIME(object):
|
|
|
# LIME visualization
|
|
# LIME visualization
|
|
|
for i, w in enumerate(weights_choices):
|
|
for i, w in enumerate(weights_choices):
|
|
|
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
|
|
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
|
|
|
|
|
+ nums_to_show.append(num_to_show)
|
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
|
)
|
|
)
|
|
|
axes[ncols + i].imshow(mark_boundaries(temp, mask))
|
|
axes[ncols + i].imshow(mark_boundaries(temp, mask))
|
|
|
- axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
|
|
|
|
|
|
|
+ axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
|
|
|
|
|
|
|
|
# NormLIME visualization
|
|
# NormLIME visualization
|
|
|
self._lime.lime_explainer.local_exp = g_weights
|
|
self._lime.lime_explainer.local_exp = g_weights
|
|
|
- for i, w in enumerate(weights_choices):
|
|
|
|
|
- num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
|
|
|
|
|
|
|
+ for i, num_to_show in enumerate(nums_to_show):
|
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
|
)
|
|
)
|
|
|
axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
|
|
axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
|
|
|
- axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels")
|
|
|
|
|
|
|
+ axes[ncols * 2 + i].set_title(f"NormLIME: first {num_to_show} superpixels")
|
|
|
|
|
|
|
|
# NormLIME*LIME visualization
|
|
# NormLIME*LIME visualization
|
|
|
combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
|
|
combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
|
|
|
self._lime.lime_explainer.local_exp = combined_weights
|
|
self._lime.lime_explainer.local_exp = combined_weights
|
|
|
- for i, w in enumerate(weights_choices):
|
|
|
|
|
- num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
|
|
|
|
|
|
|
+ for i, num_to_show in enumerate(nums_to_show):
|
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
temp, mask = self._lime.lime_explainer.get_image_and_mask(
|
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
l, positive_only=False, hide_rest=False, num_features=num_to_show
|
|
|
)
|
|
)
|
|
|
axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
|
|
axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
|
|
|
- axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels")
|
|
|
|
|
|
|
+ axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
|
|
|
|
|
|
|
|
self._lime.lime_explainer.local_exp = lime_weights
|
|
self._lime.lime_explainer.local_exp = lime_weights
|
|
|
|
|
|
|
@@ -330,14 +331,6 @@ class NormLIME(object):
|
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
-def load_kmeans_model(fname):
|
|
|
|
|
- import pickle
|
|
|
|
|
- with open(fname, 'rb') as f:
|
|
|
|
|
- kmeans_model = pickle.load(f)
|
|
|
|
|
-
|
|
|
|
|
- return kmeans_model
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
|
|
def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
|
|
|
segments = lime_explainer.segments
|
|
segments = lime_explainer.segments
|
|
|
lime_weights = lime_explainer.local_exp[label]
|
|
lime_weights = lime_explainer.local_exp[label]
|
|
@@ -361,6 +354,9 @@ def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
|
|
|
n = i + 1
|
|
n = i + 1
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
|
|
+ if percentage_to_show <= 0.0:
|
|
|
|
|
+ return 5
|
|
|
|
|
+
|
|
|
if n == 0:
|
|
if n == 0:
|
|
|
return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
|
|
return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
|
|
|
|
|
|
|
@@ -391,55 +387,6 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam
|
|
|
return cam
|
|
return cam
|
|
|
|
|
|
|
|
|
|
|
|
|
-def avg_using_superpixels(features, segments):
|
|
|
|
|
- one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
|
|
|
|
|
- for x in np.unique(segments):
|
|
|
|
|
- one_list[x] = np.mean(features[segments == x], axis=0)
|
|
|
|
|
-
|
|
|
|
|
- return one_list
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def centroid_using_superpixels(features, segments):
|
|
|
|
|
- from skimage.measure import regionprops
|
|
|
|
|
- regions = regionprops(segments + 1)
|
|
|
|
|
- one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
|
|
|
|
|
- for i, r in enumerate(regions):
|
|
|
|
|
- one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
|
|
|
|
|
- # print(one_list.shape)
|
|
|
|
|
- return one_list
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def get_feature_for_kmeans(feature_map, segments):
|
|
|
|
|
- from sklearn.preprocessing import normalize
|
|
|
|
|
- centroid_feature = centroid_using_superpixels(feature_map, segments)
|
|
|
|
|
- avg_feature = avg_using_superpixels(feature_map, segments)
|
|
|
|
|
- x = np.concatenate((centroid_feature, avg_feature), axis=-1)
|
|
|
|
|
- x = normalize(x)
|
|
|
|
|
- return x
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def combine_normlime_and_lime(lime_weights, g_weights):
|
|
|
|
|
- pred_labels = lime_weights.keys()
|
|
|
|
|
- combined_weights = {y: [] for y in pred_labels}
|
|
|
|
|
-
|
|
|
|
|
- for y in pred_labels:
|
|
|
|
|
- normlized_lime_weights_y = lime_weights[y]
|
|
|
|
|
- lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
|
|
|
|
|
-
|
|
|
|
|
- normlized_g_weight_y = g_weights[y]
|
|
|
|
|
- normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
|
|
|
|
|
-
|
|
|
|
|
- combined_weights[y] = [
|
|
|
|
|
- (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
|
|
|
|
|
- for seg_k in lime_weights_dict.keys()
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- combined_weights[y] = sorted(combined_weights[y],
|
|
|
|
|
- key=lambda x: np.abs(x[1]), reverse=True)
|
|
|
|
|
-
|
|
|
|
|
- return combined_weights
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
|
|
def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
|
if isinstance(data_, str):
|
|
if isinstance(data_, str):
|
|
@@ -469,4 +416,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
|
|
|
os.path.join(
|
|
os.path.join(
|
|
|
save_outdir, f_out
|
|
save_outdir, f_out
|
|
|
)
|
|
)
|
|
|
- )
|
|
|
|
|
|
|
+ )
|