normlime_base.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import glob

import numpy as np

import paddlex.utils.logging as logging
from paddlex.interpret.as_data_reader.readers import read_image
from . import lime_base
from ._session_preparation import compute_features_for_kmeans, gen_user_home

def load_kmeans_model(fname):
    import pickle

    with open(fname, 'rb') as f:
        kmeans_model = pickle.load(f)

    return kmeans_model

def combine_normlime_and_lime(lime_weights, g_weights):
    pred_labels = lime_weights.keys()
    combined_weights = {y: [] for y in pred_labels}

    for y in pred_labels:
        normalized_lime_weights_y = lime_weights[y]
        lime_weights_dict = {
            tuple_w[0]: tuple_w[1]
            for tuple_w in normalized_lime_weights_y
        }

        normalized_g_weight_y = g_weights[y]
        normlime_weights_dict = {
            tuple_w[0]: tuple_w[1]
            for tuple_w in normalized_g_weight_y
        }

        # For each segment, the combined weight is the product of its LIME
        # weight and its NormLIME weight, sorted by magnitude.
        combined_weights[y] = [
            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
            for seg_k in lime_weights_dict.keys()
        ]

        combined_weights[y] = sorted(
            combined_weights[y], key=lambda x: np.abs(x[1]), reverse=True)

    return combined_weights
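
# A minimal sketch of the combination above, with hypothetical numbers:
#
#     lime_weights = {0: [(3, 0.5), (7, -0.2)]}
#     g_weights = {0: [(3, 0.8), (7, 0.1)]}
#     combine_normlime_and_lime(lime_weights, g_weights)
#     # -> {0: [(3, 0.4), (7, -0.02)]}, sorted by |weight|, descending.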

def avg_using_superpixels(features, segments):
    # Average the feature vectors of all pixels within each superpixel.
    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
    for x in np.unique(segments):
        one_list[x] = np.mean(features[segments == x], axis=0)

    return one_list


def centroid_using_superpixels(features, segments):
    # Take the feature vector at the (rounded) centroid of each superpixel.
    from skimage.measure import regionprops
    regions = regionprops(segments + 1)
    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
    for i, r in enumerate(regions):
        one_list[i] = features[int(r.centroid[0] + 0.5),
                               int(r.centroid[1] + 0.5), :]
    return one_list

def get_feature_for_kmeans(feature_map, segments):
    from sklearn.preprocessing import normalize
    centroid_feature = centroid_using_superpixels(feature_map, segments)
    avg_feature = avg_using_superpixels(feature_map, segments)
    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
    x = normalize(x)
    return x
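
# Shape sketch (sizes are hypothetical): for a feature map of shape
# (H, W, C) and a segmentation with S superpixels labelled 0..S-1,
# centroid_using_superpixels and avg_using_superpixels each return (S, C),
# so get_feature_for_kmeans returns an L2-normalized array of shape (S, 2C).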

def precompute_normlime_weights(list_data_,
                                predict_fn,
                                num_samples=3000,
                                batch_size=50,
                                save_dir='./tmp'):
    # Save LIME weights and k-means cluster labels for each image.
    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
                            save_dir)

    # Load the precomputed results, then compute and save NormLIME weights.
    fname_list = glob.glob(
        os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))
    return compute_normlime_weights(fname_list, save_dir, num_samples)

def save_one_lime_predict_and_kmean_labels(lime_all_weights,
                                           image_pred_labels, cluster_labels,
                                           save_path):
    lime_weights = {}
    for label in image_pred_labels:
        lime_weights[label] = lime_all_weights[label]

    for_normlime_weights = {
        'lime_weights': lime_weights,  # class_label -> list of (seg_label, weight)
        'cluster': cluster_labels  # cluster label of each segment, indexed by segment.
    }

    np.save(save_path, for_normlime_weights)

def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
                            save_dir):
    # Imported here rather than at module level to avoid a circular import.
    import paddlex as pdx

    root_path = gen_user_home()
    root_path = osp.join(root_path, '.paddlex')
    h_pre_models = osp.join(root_path, "pre_models")
    if not osp.exists(h_pre_models):
        if not osp.exists(root_path):
            os.makedirs(root_path)
        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
        pdx.utils.download_and_decompress(url, path=root_path)

    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
    kmeans_model = load_kmeans_model(h_pre_models_kmeans)

    for data_index, each_data_ in enumerate(list_data_):
        if isinstance(each_data_, str):
            save_path = "lime_weights_s{}_{}.npy".format(
                num_samples, each_data_.split('/')[-1].split('.')[0])
        else:
            save_path = "lime_weights_s{}_{}.npy".format(num_samples,
                                                         data_index)
        save_path = os.path.join(save_dir, save_path)

        if os.path.exists(save_path):
            logging.info(
                save_path + ' exists, not computing this one.',
                use_color=True)
            continue

        img_file_name = each_data_ if isinstance(each_data_,
                                                 str) else str(data_index)
        logging.info(
            'processing ' + img_file_name + ' [{}/{}]'.format(
                data_index, len(list_data_)),
            use_color=True)

        image_show = read_image(each_data_)
        result = predict_fn(image_show)
        result = result[0]  # only one image here.

        if abs(np.sum(result) - 1.0) > 1e-4:
            # `result` is not yet a probability distribution: apply softmax.
            exp_result = np.exp(result)
            probability = exp_result / np.sum(exp_result)
        else:
            probability = result

        pred_label = np.argsort(probability)[::-1]

        # Keep the top predicted labels (at most 5) whose probability
        # exceeds the threshold.
        threshold = 0.05
        top_k = 0
        for l in pred_label:
            if probability[l] < threshold or top_k == 5:
                break
            top_k += 1

        if top_k == 0:
            top_k = 1

        pred_label = pred_label[:top_k]

        algo = lime_base.LimeImageInterpreter()
        interpreter = algo.interpret_instance(
            image_show[0],
            predict_fn,
            pred_label,
            0,
            num_samples=num_samples,
            batch_size=batch_size)

        X = get_feature_for_kmeans(
            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
            interpreter.segments)
        try:
            cluster_labels = kmeans_model.predict(X)
        except AttributeError:
            # An unpickled model from a different scikit-learn version may
            # not expose predict(); fall back to assigning each row to its
            # nearest cluster center.
            from sklearn.metrics import pairwise_distances_argmin_min
            cluster_labels, _ = pairwise_distances_argmin_min(
                X, kmeans_model.cluster_centers_)

        save_one_lime_predict_and_kmean_labels(
            interpreter.local_weights, pred_label, cluster_labels, save_path)

def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
    normlime_weights_all_labels = {}

    for f in a_list_lime_fnames:
        try:
            lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
            lime_weights = lime_weights_and_cluster['lime_weights']
            cluster = lime_weights_and_cluster['cluster']
        except Exception:
            logging.info('When loading precomputed LIME result, skipping ' +
                         str(f))
            continue
        logging.info('Loading precomputed LIME result, ' + str(f))

        pred_labels = lime_weights.keys()
        for y in pred_labels:
            normlime_weights = normlime_weights_all_labels.get(y, {})
            w_f_y = [abs(w[1]) for w in lime_weights[y]]
            w_f_y_l1norm = sum(w_f_y)

            for w in lime_weights[y]:
                seg_label = w[0]
                # Squared LIME weight, normalized by the L1 norm of this
                # image's weights, collected per cluster.
                weight = w[1] * w[1] / w_f_y_l1norm
                a = normlime_weights.get(cluster[seg_label], [])
                a.append(weight)
                normlime_weights[cluster[seg_label]] = a

            normlime_weights_all_labels[y] = normlime_weights

    # Compute NormLIME: average the collected weights within each cluster.
    for y in normlime_weights_all_labels:
        normlime_weights = normlime_weights_all_labels.get(y, {})
        for k in normlime_weights:
            normlime_weights[k] = sum(normlime_weights[k]) / len(
                normlime_weights[k])

    # Check NormLIME: warn if some classes never appeared in the results.
    if len(normlime_weights_all_labels.keys()) < max(
            normlime_weights_all_labels.keys()) + 1:
        logging.info(
            "\n" +
            "Warning: !!! \n" +
            "There are at least {} classes, ".format(
                max(normlime_weights_all_labels.keys()) + 1) +
            "but NormLIME has results for only {} classes. \n".format(
                len(normlime_weights_all_labels.keys())) +
            "This may cause unstable results in later computations, " +
            "which can be improved by computing more test samples." +
            "\n")

    # Find an output file name that does not exist yet.
    n = 0
    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
        lime_num_samples, len(a_list_lime_fnames), n)
    while os.path.exists(os.path.join(save_dir, f_out)):
        n += 1
        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
            lime_num_samples, len(a_list_lime_fnames), n)

    np.save(os.path.join(save_dir, f_out), normlime_weights_all_labels)
    return os.path.join(save_dir, f_out)
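
# A hypothetical end-to-end usage sketch (not part of the original module).
# `my_predict_fn` stands in for any callable that maps a batch of images of
# shape (N, H, W, 3) to per-class scores of shape (N, num_classes), and
# `image_list` is a list of image file paths; both names are assumptions.
#
#     image_list = glob.glob('./test_images/*.jpg')
#     normlime_weights_file = precompute_normlime_weights(
#         image_list, my_predict_fn,
#         num_samples=3000, batch_size=50, save_dir='./tmp')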