normlime_base.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import glob

import numpy as np
import tqdm

import paddlex as pdx  # needed below for pdx.utils.download_and_decompress
import paddlex.utils.logging as logging
from paddlex.interpret.as_data_reader.readers import read_image

from . import lime_base
from ._session_preparation import compute_features_for_kmeans, gen_user_home


def load_kmeans_model(fname):
    import pickle
    with open(fname, 'rb') as f:
        kmeans_model = pickle.load(f)

    return kmeans_model


def combine_normlime_and_lime(lime_weights, g_weights):
    pred_labels = lime_weights.keys()
    combined_weights = {y: [] for y in pred_labels}

    for y in pred_labels:
        normalized_lime_weights_y = lime_weights[y]
        lime_weights_dict = {
            tuple_w[0]: tuple_w[1]
            for tuple_w in normalized_lime_weights_y
        }

        normalized_g_weight_y = g_weights[y]
        normlime_weights_dict = {
            tuple_w[0]: tuple_w[1]
            for tuple_w in normalized_g_weight_y
        }

        # multiply the per-segment LIME weight by the global NormLIME weight.
        combined_weights[y] = [
            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
            for seg_k in lime_weights_dict.keys()
        ]

        # rank segments by the magnitude of the combined weight.
        combined_weights[y] = sorted(
            combined_weights[y], key=lambda x: np.abs(x[1]), reverse=True)

    return combined_weights
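
# A minimal usage sketch for combine_normlime_and_lime (the toy weights below
# are illustrative only, not real outputs). Both inputs map a predicted class
# label to a list of (segment_label, weight) tuples:
#
#     lime_weights = {0: [(0, 0.6), (1, -0.2)]}
#     g_weights = {0: [(0, 0.5), (1, 0.5)]}
#     combine_normlime_and_lime(lime_weights, g_weights)
#     # -> {0: [(0, 0.3), (1, -0.1)]}, sorted by |weight|.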


def avg_using_superpixels(features, segments):
    # average the (H, W, C) feature map over each superpixel.
    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
    for x in np.unique(segments):
        one_list[x] = np.mean(features[segments == x], axis=0)

    return one_list


def centroid_using_superpixels(features, segments):
    from skimage.measure import regionprops
    # regionprops ignores label 0, so shift the segment labels by one.
    regions = regionprops(segments + 1)
    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
    for i, r in enumerate(regions):
        # take the feature vector at the (rounded) centroid of each superpixel.
        one_list[i] = features[int(r.centroid[0] + 0.5),
                               int(r.centroid[1] + 0.5), :]

    return one_list


def get_feature_for_kmeans(feature_map, segments):
    from sklearn.preprocessing import normalize
    centroid_feature = centroid_using_superpixels(feature_map, segments)
    avg_feature = avg_using_superpixels(feature_map, segments)
    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
    x = normalize(x)
    return x
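
# Shape note: with an (H, W, C) feature_map and S distinct segment labels, the
# centroid and average features are each (S, C), so the concatenated and
# L2-normalized output of get_feature_for_kmeans has shape (S, 2 * C).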


def precompute_normlime_weights(list_data_,
                                predict_fn,
                                num_samples=3000,
                                batch_size=50,
                                save_dir='./tmp'):
    # save lime weights and kmeans cluster labels.
    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
                            save_dir)

    # load precomputed results, compute normlime weights and save.
    fname_list = glob.glob(
        os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples)))

    return compute_normlime_weights(fname_list, save_dir, num_samples)
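
# Hypothetical call sketch (the image paths and predict_fn are placeholders,
# not part of this module); predict_fn should accept a batch of images and
# return per-class scores:
#
#     fname = precompute_normlime_weights(
#         ['a.jpg', 'b.jpg'], predict_fn, num_samples=3000, save_dir='./tmp')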


def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels,
                                           cluster_labels, save_path):
    lime_weights = {}
    for label in image_pred_labels:
        lime_weights[label] = lime_all_weights[label]

    for_normlime_weights = {
        'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
        'cluster': cluster_labels  # a list with segments as indices.
    }

    np.save(save_path, for_normlime_weights)


def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size,
                            save_dir):
    root_path = gen_user_home()
    root_path = osp.join(root_path, '.paddlex')
    h_pre_models = osp.join(root_path, "pre_models")
    if not osp.exists(h_pre_models):
        if not osp.exists(root_path):
            os.makedirs(root_path)
        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
        pdx.utils.download_and_decompress(url, path=root_path)

    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
    kmeans_model = load_kmeans_model(h_pre_models_kmeans)

    for data_index, each_data_ in enumerate(list_data_):
        if isinstance(each_data_, str):
            save_path = "lime_weights_s{}_{}.npy".format(
                num_samples, each_data_.split('/')[-1].split('.')[0])
            save_path = os.path.join(save_dir, save_path)
        else:
            save_path = "lime_weights_s{}_{}.npy".format(num_samples,
                                                         data_index)
            save_path = os.path.join(save_dir, save_path)

        if os.path.exists(save_path):
            logging.info(
                save_path + ' exists, not computing this one.', use_color=True)
            continue

        img_file_name = each_data_ if isinstance(each_data_,
                                                 str) else str(data_index)
        logging.info(
            'processing ' + img_file_name + ' [{}/{}]'.format(data_index,
                                                              len(list_data_)),
            use_color=True)

        image_show = read_image(each_data_)
        result = predict_fn(image_show)
        result = result[0]  # only one image here.

        if abs(np.sum(result) - 1.0) > 1e-4:
            # the outputs are raw scores; apply softmax to get probabilities.
            exp_result = np.exp(result)
            probability = exp_result / np.sum(exp_result)
        else:
            probability = result

        pred_label = np.argsort(probability)[::-1]

        # keep the top labels (at most 5) whose probability exceeds threshold.
        threshold = 0.05
        top_k = 0
        for l in pred_label:
            if probability[l] < threshold or top_k == 5:
                break
            top_k += 1

        if top_k == 0:
            top_k = 1

        pred_label = pred_label[:top_k]

        algo = lime_base.LimeImageInterpreter()
        interpreter = algo.interpret_instance(
            image_show[0],
            predict_fn,
            pred_label,
            0,
            num_samples=num_samples,
            batch_size=batch_size)

        X = get_feature_for_kmeans(
            compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
            interpreter.segments)
        try:
            cluster_labels = kmeans_model.predict(X)
        except AttributeError:
            # older KMeans pickles may lack predict(); fall back to assigning
            # each sample to its nearest cluster center.
            from sklearn.metrics import pairwise_distances_argmin_min
            cluster_labels, _ = pairwise_distances_argmin_min(
                X, kmeans_model.cluster_centers_)
        save_one_lime_predict_and_kmean_labels(
            interpreter.local_weights, pred_label, cluster_labels, save_path)
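
# Caching note: each image's result is saved as
# 'lime_weights_s{num_samples}_{image_stem_or_index}.npy' under save_dir, so
# re-running precompute_lime_weights skips any image whose file already exists.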


def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
    normlime_weights_all_labels = {}

    for f in a_list_lime_fnames:
        try:
            lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
            lime_weights = lime_weights_and_cluster['lime_weights']
            cluster = lime_weights_and_cluster['cluster']
        except Exception:
            logging.info('When loading precomputed LIME result, skipping ' +
                         str(f))
            continue
        logging.info('Loading precomputed LIME result, ' + str(f))

        pred_labels = lime_weights.keys()
        for y in pred_labels:
            normlime_weights = normlime_weights_all_labels.get(y, {})
            # L1 norm of this image's LIME weights for class y.
            w_f_y = [abs(w[1]) for w in lime_weights[y]]
            w_f_y_l1norm = sum(w_f_y)

            for w in lime_weights[y]:
                seg_label = w[0]
                weight = w[1] * w[1] / w_f_y_l1norm
                a = normlime_weights.get(cluster[seg_label], [])
                a.append(weight)
                normlime_weights[cluster[seg_label]] = a

            normlime_weights_all_labels[y] = normlime_weights

    # compute normlime: average the collected weights within each cluster.
    for y in normlime_weights_all_labels:
        normlime_weights = normlime_weights_all_labels.get(y, {})
        for k in normlime_weights:
            normlime_weights[k] = sum(normlime_weights[k]) / len(
                normlime_weights[k])

    # check normlime
    if len(normlime_weights_all_labels.keys()) < max(
            normlime_weights_all_labels.keys()) + 1:
        logging.info(
            "\nWarning: !!! \n"
            "There are at least {} classes, but NormLIME has results for "
            "only {} classes. \n"
            "This may cause unstable results in the later computation, but "
            "it can be improved by computing more test samples.\n".format(
                max(normlime_weights_all_labels.keys()) + 1,
                len(normlime_weights_all_labels.keys())))

    # pick an output file name that does not collide with existing results.
    n = 0
    f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
        lime_num_samples, len(a_list_lime_fnames), n)
    while os.path.exists(os.path.join(save_dir, f_out)):
        n += 1
        f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(
            lime_num_samples, len(a_list_lime_fnames), n)

    np.save(os.path.join(save_dir, f_out), normlime_weights_all_labels)
    return os.path.join(save_dir, f_out)
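
# Aggregation note: for each image and class y, a segment's LIME weight w is
# normalized to w**2 / ||w_y||_1 (w_y being that image's weight vector for y),
# pooled into the segment's k-means cluster, and finally averaged per cluster
# across all images, yielding one global importance per (class, cluster) pair.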


def precompute_global_classifier(dataset,
                                 predict_fn,
                                 save_path,
                                 batch_size=50,
                                 max_num_samples=1000):
    from sklearn.linear_model import LogisticRegression

    root_path = gen_user_home()
    root_path = osp.join(root_path, '.paddlex')
    h_pre_models = osp.join(root_path, "pre_models")
    if not osp.exists(h_pre_models):
        if not osp.exists(root_path):
            os.makedirs(root_path)
        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
        pdx.utils.download_and_decompress(url, path=root_path)

    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
    kmeans_model = load_kmeans_model(h_pre_models_kmeans)

    image_list = []
    for item in dataset.file_list:
        image_list.append(item[0])

    x_data = []
    y_labels = []

    num_features = len(kmeans_model.cluster_centers_)

    logging.info(
        "Initialization for NormLIME: Computing each sample in the test list.",
        use_color=True)

    for each_data_ in tqdm.tqdm(image_list):
        x_data_i = np.zeros((num_features))
        image_show = read_image(each_data_)
        result = predict_fn(image_show)
        result = result[0]  # only one image here.

        feature_map = compute_features_for_kmeans(image_show).transpose(
            (1, 2, 0))

        # build a fixed num_blocks x num_blocks grid segmentation.
        segments = np.zeros((image_show.shape[1], image_show.shape[2]),
                            np.int32)
        num_blocks = 10
        height_per_i = segments.shape[0] // num_blocks + 1
        width_per_i = segments.shape[1] // num_blocks + 1

        for i in range(segments.shape[0]):
            for j in range(segments.shape[1]):
                segments[i, j] = i // height_per_i * num_blocks \
                    + j // width_per_i

        # segments = quickshift(image_show[0], sigma=1)
        X = get_feature_for_kmeans(feature_map, segments)

        try:
            cluster_labels = kmeans_model.predict(X)
        except AttributeError:
            from sklearn.metrics import pairwise_distances_argmin_min
            cluster_labels, _ = pairwise_distances_argmin_min(
                X, kmeans_model.cluster_centers_)
        # bag-of-clusters encoding: mark which clusters occur in this image.
        for c in cluster_labels:
            x_data_i[c] = 1

        # x_data_i /= len(cluster_labels)

        pred_y_i = np.argmax(result)
        y_labels.append(pred_y_i)
        x_data.append(x_data_i)

    if len(np.unique(y_labels)) < 2:
        logging.info("Warning: The test samples in the dataset are limited.\n"
                     "NormLIME may have no effect on the results.\n"
                     "Try to add more test samples, or see the results of LIME.")
        # not enough classes to fit a classifier; fall back to uniform weights.
        num_classes = np.max(np.unique(y_labels)) + 1
        normlime_weights_all_labels = {}
        for class_index in range(num_classes):
            w = np.ones((num_features)) / num_features
            normlime_weights_all_labels[class_index] = {
                i: wi
                for i, wi in enumerate(w)
            }
        logging.info("Saving the computed normlime_weights in {}".format(
            save_path))
        np.save(save_path, normlime_weights_all_labels)
        return save_path

    clf = LogisticRegression(multi_class='multinomial', max_iter=1000)
    clf.fit(x_data, y_labels)

    num_classes = np.max(np.unique(y_labels)) + 1
    normlime_weights_all_labels = {}

    if len(y_labels) / len(np.unique(y_labels)) < 3:
        logging.info("Warning: The test samples in the dataset are limited.\n"
                     "NormLIME may have no effect on the results.\n"
                     "Try to add more test samples, or see the results of LIME.")

    if len(np.unique(y_labels)) == 2:
        # binary: clf.coef_ has shape [1, num_features].
        for class_index in range(num_classes):
            if class_index not in clf.classes_:
                # class unseen in the test samples; fall back to uniform weights.
                w = np.ones((num_features)) / num_features
                normlime_weights_all_labels[class_index] = {
                    i: wi
                    for i, wi in enumerate(w)
                }
                continue

            if clf.classes_[0] == class_index:
                w = -clf.coef_[0]
            else:
                w = clf.coef_[0]

            # softmax (with a temperature of 10) over the coefficients.
            w = w - np.max(w)
            exp_w = np.exp(w * 10)
            w = exp_w / np.sum(exp_w)

            normlime_weights_all_labels[class_index] = {
                i: wi
                for i, wi in enumerate(w)
            }
    else:
        # multi-class: clf.coef_ has shape [len(np.unique(y_labels)), num_features].
        for class_index in range(num_classes):
            if class_index not in clf.classes_:
                # class unseen in the test samples; fall back to uniform weights.
                w = np.ones((num_features)) / num_features
                normlime_weights_all_labels[class_index] = {
                    i: wi
                    for i, wi in enumerate(w)
                }
                continue

            coef_class_index = np.where(clf.classes_ == class_index)[0][0]
            w = clf.coef_[coef_class_index]

            # softmax (with a temperature of 10) over the coefficients.
            w = w - np.max(w)
            exp_w = np.exp(w * 10)
            w = exp_w / np.sum(exp_w)

            normlime_weights_all_labels[class_index] = {
                i: wi
                for i, wi in enumerate(w)
            }

    logging.info("Saving the computed normlime_weights in {}".format(
        save_path))
    np.save(save_path, normlime_weights_all_labels)
    return save_path
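
# Hypothetical usage sketch (the dataset and predict_fn below are placeholders;
# predict_fn must accept a batch of images and return per-class scores):
#
#     weights_path = precompute_global_classifier(
#         test_dataset, predict_fn, save_path='./normlime_weights.npy')
#     normlime_weights = np.load(weights_path, allow_pickle=True).item()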