interpretation_algorithms.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. import time
  18. from . import lime_base
  19. from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, gen_user_home
  20. from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
  21. from paddlex.interpret.as_data_reader.readers import read_image
  22. import paddlex.utils.logging as logging
  23. import cv2
  24. class CAM(object):
  25. def __init__(self, predict_fn, label_names):
  26. """
  27. Args:
  28. predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
  29. output: [
  30. logits [N, num_classes],
  31. feature map before global average pooling [N, num_channels, h_, w_]
  32. ]
  33. """
  34. self.predict_fn = predict_fn
  35. self.label_names = label_names
  36. def preparation_cam(self, data_):
  37. image_show = read_image(data_)
  38. result = self.predict_fn(image_show)
  39. logit = result[0][0]
  40. if abs(np.sum(logit) - 1.0) > 1e-4:
  41. # softmax
  42. logit = logit - np.max(logit)
  43. exp_result = np.exp(logit)
  44. probability = exp_result / np.sum(exp_result)
  45. else:
  46. probability = logit
  47. # only interpret top 1
  48. pred_label = np.argsort(probability)
  49. pred_label = pred_label[-1:]
  50. self.predicted_label = pred_label[0]
  51. self.predicted_probability = probability[pred_label[0]]
  52. self.image = image_show[0]
  53. self.labels = pred_label
  54. fc_weights = paddle_get_fc_weights()
  55. feature_maps = result[1]
  56. l = pred_label[0]
  57. ln = l
  58. if self.label_names is not None:
  59. ln = self.label_names[l]
  60. prob_str = "%.3f" % (probability[pred_label[0]])
  61. logging.info("predicted result: {} with probability {}.".format(
  62. ln, prob_str))
  63. return feature_maps, fc_weights
  64. def interpret(self, data_, visualization=True, save_outdir=None):
  65. feature_maps, fc_weights = self.preparation_cam(data_)
  66. cam = get_cam(self.image, feature_maps, fc_weights,
  67. self.predicted_label)
  68. if visualization or save_outdir is not None:
  69. import matplotlib.pyplot as plt
  70. from skimage.segmentation import mark_boundaries
  71. l = self.labels[0]
  72. ln = l
  73. if self.label_names is not None:
  74. ln = self.label_names[l]
  75. psize = 5
  76. nrows = 1
  77. ncols = 2
  78. plt.close()
  79. f, axes = plt.subplots(
  80. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  81. for ax in axes.ravel():
  82. ax.axis("off")
  83. axes = axes.ravel()
  84. axes[0].imshow(self.image)
  85. prob_str = "{%.3f}" % (self.predicted_probability)
  86. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  87. axes[1].imshow(cam)
  88. axes[1].set_title("CAM")
  89. if save_outdir is not None:
  90. os.makedirs(save_outdir, exist_ok=True)
  91. save_fig(data_, save_outdir, 'cam')
  92. if visualization:
  93. plt.show()
  94. return
  95. class LIME(object):
  96. def __init__(self,
  97. predict_fn,
  98. label_names,
  99. num_samples=3000,
  100. batch_size=50):
  101. """
  102. LIME wrapper. See lime_base.py for the detailed LIME implementation.
  103. Args:
  104. predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
  105. num_samples: the number of samples that LIME takes for fitting.
  106. batch_size: batch size for model inference each time.
  107. """
  108. self.num_samples = num_samples
  109. self.batch_size = batch_size
  110. self.predict_fn = predict_fn
  111. self.labels = None
  112. self.image = None
  113. self.lime_interpreter = None
  114. self.label_names = label_names
  115. def preparation_lime(self, data_):
  116. image_show = read_image(data_)
  117. result = self.predict_fn(image_show)
  118. result = result[0] # only one image here.
  119. if abs(np.sum(result) - 1.0) > 1e-4:
  120. # softmax
  121. result = result - np.max(result)
  122. exp_result = np.exp(result)
  123. probability = exp_result / np.sum(exp_result)
  124. else:
  125. probability = result
  126. # only interpret top 1
  127. pred_label = np.argsort(probability)
  128. pred_label = pred_label[-1:]
  129. self.predicted_label = pred_label[0]
  130. self.predicted_probability = probability[pred_label[0]]
  131. self.image = image_show[0]
  132. self.labels = pred_label
  133. l = pred_label[0]
  134. ln = l
  135. if self.label_names is not None:
  136. ln = self.label_names[l]
  137. prob_str = "%.3f" % (probability[pred_label[0]])
  138. logging.info("predicted result: {} with probability {}.".format(
  139. ln, prob_str))
  140. end = time.time()
  141. algo = lime_base.LimeImageInterpreter()
  142. interpreter = algo.interpret_instance(
  143. self.image,
  144. self.predict_fn,
  145. self.labels,
  146. 0,
  147. num_samples=self.num_samples,
  148. batch_size=self.batch_size)
  149. self.lime_interpreter = interpreter
  150. logging.info('lime time: ' + str(time.time() - end) + 's.')
  151. def interpret(self, data_, visualization=True, save_outdir=None):
  152. if self.lime_interpreter is None:
  153. self.preparation_lime(data_)
  154. if visualization or save_outdir is not None:
  155. import matplotlib.pyplot as plt
  156. from skimage.segmentation import mark_boundaries
  157. l = self.labels[0]
  158. ln = l
  159. if self.label_names is not None:
  160. ln = self.label_names[l]
  161. psize = 5
  162. nrows = 2
  163. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  164. ncols = len(weights_choices)
  165. plt.close()
  166. f, axes = plt.subplots(
  167. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  168. for ax in axes.ravel():
  169. ax.axis("off")
  170. axes = axes.ravel()
  171. axes[0].imshow(self.image)
  172. prob_str = "{%.3f}" % (self.predicted_probability)
  173. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  174. axes[1].imshow(
  175. mark_boundaries(self.image, self.lime_interpreter.segments))
  176. axes[1].set_title("superpixel segmentation")
  177. # LIME visualization
  178. for i, w in enumerate(weights_choices):
  179. num_to_show = auto_choose_num_features_to_show(
  180. self.lime_interpreter, l, w)
  181. temp, mask = self.lime_interpreter.get_image_and_mask(
  182. l,
  183. positive_only=False,
  184. hide_rest=False,
  185. num_features=num_to_show)
  186. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  187. axes[ncols + i].set_title(
  188. "label {}, first {} superpixels".format(ln, num_to_show))
  189. if save_outdir is not None:
  190. os.makedirs(save_outdir, exist_ok=True)
  191. save_fig(data_, save_outdir, 'lime', self.num_samples)
  192. if visualization:
  193. plt.show()
  194. return
  195. class NormLIMEStandard(object):
  196. def __init__(self,
  197. predict_fn,
  198. label_names,
  199. num_samples=3000,
  200. batch_size=50,
  201. kmeans_model_for_normlime=None,
  202. normlime_weights=None):
  203. root_path = gen_user_home()
  204. root_path = osp.join(root_path, '.paddlex')
  205. h_pre_models = osp.join(root_path, "pre_models")
  206. if not osp.exists(h_pre_models):
  207. if not osp.exists(root_path):
  208. os.makedirs(root_path)
  209. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  210. pdx.utils.download_and_decompress(url, path=root_path)
  211. h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
  212. if kmeans_model_for_normlime is None:
  213. try:
  214. self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
  215. except:
  216. raise ValueError(
  217. "NormLIME needs the KMeans model, where we provided a default one in "
  218. "pre_models/kmeans_model.pkl.")
  219. else:
  220. logging.debug("Warning: It is *strongly* suggested to use the \
  221. default KMeans model in pre_models/kmeans_model.pkl. \
  222. Use another one will change the final result.")
  223. self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
  224. self.num_samples = num_samples
  225. self.batch_size = batch_size
  226. try:
  227. self.normlime_weights = np.load(
  228. normlime_weights, allow_pickle=True).item()
  229. except:
  230. self.normlime_weights = None
  231. logging.debug(
  232. "Warning: not find the correct precomputed Normlime result.")
  233. self.predict_fn = predict_fn
  234. self.labels = None
  235. self.image = None
  236. self.label_names = label_names
  237. def predict_cluster_labels(self, feature_map, segments):
  238. X = get_feature_for_kmeans(feature_map, segments)
  239. try:
  240. cluster_labels = self.kmeans_model.predict(X)
  241. except AttributeError:
  242. from sklearn.metrics import pairwise_distances_argmin_min
  243. cluster_labels, _ = pairwise_distances_argmin_min(
  244. X, self.kmeans_model.cluster_centers_)
  245. return cluster_labels
  246. def predict_using_normlime_weights(self, pred_labels,
  247. predicted_cluster_labels):
  248. # global weights
  249. g_weights = {y: [] for y in pred_labels}
  250. for y in pred_labels:
  251. cluster_weights_y = self.normlime_weights.get(y, {})
  252. g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
  253. for i, k in enumerate(predicted_cluster_labels)]
  254. g_weights[y] = sorted(
  255. g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
  256. return g_weights
  257. def preparation_normlime(self, data_):
  258. self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
  259. self.batch_size)
  260. self._lime.preparation_lime(data_)
  261. image_show = read_image(data_)
  262. self.predicted_label = self._lime.predicted_label
  263. self.predicted_probability = self._lime.predicted_probability
  264. self.image = image_show[0]
  265. self.labels = self._lime.labels
  266. logging.info('performing NormLIME operations ...')
  267. cluster_labels = self.predict_cluster_labels(
  268. compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
  269. self._lime.lime_interpreter.segments)
  270. g_weights = self.predict_using_normlime_weights(self.labels,
  271. cluster_labels)
  272. return g_weights
  273. def interpret(self, data_, visualization=True, save_outdir=None):
  274. if self.normlime_weights is None:
  275. raise ValueError(
  276. "Not find the correct precomputed NormLIME result. \n"
  277. "\t Try to call compute_normlime_weights() first or load the correct path."
  278. )
  279. g_weights = self.preparation_normlime(data_)
  280. lime_weights = self._lime.lime_interpreter.local_weights
  281. if visualization or save_outdir is not None:
  282. import matplotlib.pyplot as plt
  283. from skimage.segmentation import mark_boundaries
  284. l = self.labels[0]
  285. ln = l
  286. if self.label_names is not None:
  287. ln = self.label_names[l]
  288. psize = 5
  289. nrows = 4
  290. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  291. nums_to_show = []
  292. ncols = len(weights_choices)
  293. plt.close()
  294. f, axes = plt.subplots(
  295. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  296. for ax in axes.ravel():
  297. ax.axis("off")
  298. axes = axes.ravel()
  299. axes[0].imshow(self.image)
  300. prob_str = "{%.3f}" % (self.predicted_probability)
  301. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  302. axes[1].imshow(
  303. mark_boundaries(self.image,
  304. self._lime.lime_interpreter.segments))
  305. axes[1].set_title("superpixel segmentation")
  306. # LIME visualization
  307. for i, w in enumerate(weights_choices):
  308. num_to_show = auto_choose_num_features_to_show(
  309. self._lime.lime_interpreter, l, w)
  310. nums_to_show.append(num_to_show)
  311. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  312. l,
  313. positive_only=False,
  314. hide_rest=False,
  315. num_features=num_to_show)
  316. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  317. axes[ncols + i].set_title("LIME: first {} superpixels".format(
  318. num_to_show))
  319. # NormLIME visualization
  320. self._lime.lime_interpreter.local_weights = g_weights
  321. for i, num_to_show in enumerate(nums_to_show):
  322. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  323. l,
  324. positive_only=False,
  325. hide_rest=False,
  326. num_features=num_to_show)
  327. axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
  328. axes[ncols * 2 + i].set_title(
  329. "NormLIME: first {} superpixels".format(num_to_show))
  330. # NormLIME*LIME visualization
  331. combined_weights = combine_normlime_and_lime(lime_weights,
  332. g_weights)
  333. self._lime.lime_interpreter.local_weights = combined_weights
  334. for i, num_to_show in enumerate(nums_to_show):
  335. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  336. l,
  337. positive_only=False,
  338. hide_rest=False,
  339. num_features=num_to_show)
  340. axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
  341. axes[ncols * 3 + i].set_title(
  342. "Combined: first {} superpixels".format(num_to_show))
  343. self._lime.lime_interpreter.local_weights = lime_weights
  344. if save_outdir is not None:
  345. os.makedirs(save_outdir, exist_ok=True)
  346. save_fig(data_, save_outdir, 'normlime', self.num_samples)
  347. if visualization:
  348. plt.show()
  349. class NormLIME(object):
  350. def __init__(self,
  351. predict_fn,
  352. label_names,
  353. num_samples=3000,
  354. batch_size=50,
  355. kmeans_model_for_normlime=None,
  356. normlime_weights=None):
  357. root_path = gen_user_home()
  358. root_path = osp.join(root_path, '.paddlex')
  359. h_pre_models = osp.join(root_path, "pre_models")
  360. if not osp.exists(h_pre_models):
  361. if not osp.exists(root_path):
  362. os.makedirs(root_path)
  363. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  364. pdx.utils.download_and_decompress(url, path=root_path)
  365. h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
  366. if kmeans_model_for_normlime is None:
  367. try:
  368. self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
  369. except:
  370. raise ValueError(
  371. "NormLIME needs the KMeans model, where we provided a default one in "
  372. "pre_models/kmeans_model.pkl.")
  373. else:
  374. logging.debug("Warning: It is *strongly* suggested to use the \
  375. default KMeans model in pre_models/kmeans_model.pkl. \
  376. Use another one will change the final result.")
  377. self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
  378. self.num_samples = num_samples
  379. self.batch_size = batch_size
  380. try:
  381. self.normlime_weights = np.load(
  382. normlime_weights, allow_pickle=True).item()
  383. except:
  384. self.normlime_weights = None
  385. logging.debug(
  386. "Warning: not find the correct precomputed Normlime result.")
  387. self.predict_fn = predict_fn
  388. self.labels = None
  389. self.image = None
  390. self.label_names = label_names
  391. def predict_cluster_labels(self, feature_map, segments):
  392. X = get_feature_for_kmeans(feature_map, segments)
  393. try:
  394. cluster_labels = self.kmeans_model.predict(X)
  395. except AttributeError:
  396. from sklearn.metrics import pairwise_distances_argmin_min
  397. cluster_labels, _ = pairwise_distances_argmin_min(
  398. X, self.kmeans_model.cluster_centers_)
  399. return cluster_labels
  400. def predict_using_normlime_weights(self, pred_labels,
  401. predicted_cluster_labels):
  402. # global weights
  403. g_weights = {y: [] for y in pred_labels}
  404. for y in pred_labels:
  405. cluster_weights_y = self.normlime_weights.get(y, {})
  406. g_weights[y] = [(i, cluster_weights_y.get(k, 0.0))
  407. for i, k in enumerate(predicted_cluster_labels)]
  408. g_weights[y] = sorted(
  409. g_weights[y], key=lambda x: np.abs(x[1]), reverse=True)
  410. return g_weights
  411. def preparation_normlime(self, data_):
  412. self._lime = LIME(self.predict_fn, self.label_names, self.num_samples,
  413. self.batch_size)
  414. self._lime.preparation_lime(data_)
  415. image_show = read_image(data_)
  416. self.predicted_label = self._lime.predicted_label
  417. self.predicted_probability = self._lime.predicted_probability
  418. self.image = image_show[0]
  419. self.labels = self._lime.labels
  420. logging.info('performing NormLIME operations ...')
  421. cluster_labels = self.predict_cluster_labels(
  422. compute_features_for_kmeans(image_show).transpose((1, 2, 0)),
  423. self._lime.lime_interpreter.segments)
  424. g_weights = self.predict_using_normlime_weights(self.labels,
  425. cluster_labels)
  426. return g_weights
  427. def interpret(self, data_, visualization=True, save_outdir=None):
  428. if self.normlime_weights is None:
  429. raise ValueError(
  430. "Not find the correct precomputed NormLIME result. \n"
  431. "\t Try to call compute_normlime_weights() first or load the correct path."
  432. )
  433. g_weights = self.preparation_normlime(data_)
  434. lime_weights = self._lime.lime_interpreter.local_weights
  435. if visualization or save_outdir is not None:
  436. import matplotlib.pyplot as plt
  437. from skimage.segmentation import mark_boundaries
  438. l = self.labels[0]
  439. ln = l
  440. if self.label_names is not None:
  441. ln = self.label_names[l]
  442. psize = 5
  443. nrows = 4
  444. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  445. nums_to_show = []
  446. ncols = len(weights_choices)
  447. plt.close()
  448. f, axes = plt.subplots(
  449. nrows, ncols, figsize=(psize * ncols, psize * nrows))
  450. for ax in axes.ravel():
  451. ax.axis("off")
  452. axes = axes.ravel()
  453. axes[0].imshow(self.image)
  454. prob_str = "{%.3f}" % (self.predicted_probability)
  455. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  456. axes[1].imshow(
  457. mark_boundaries(self.image,
  458. self._lime.lime_interpreter.segments))
  459. axes[1].set_title("superpixel segmentation")
  460. # LIME visualization
  461. for i, w in enumerate(weights_choices):
  462. num_to_show = auto_choose_num_features_to_show(
  463. self._lime.lime_interpreter, l, w)
  464. nums_to_show.append(num_to_show)
  465. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  466. l,
  467. positive_only=True,
  468. hide_rest=False,
  469. num_features=num_to_show)
  470. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  471. axes[ncols + i].set_title("LIME: first {} superpixels".format(
  472. num_to_show))
  473. # NormLIME visualization
  474. self._lime.lime_interpreter.local_weights = g_weights
  475. for i, num_to_show in enumerate(nums_to_show):
  476. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  477. l,
  478. positive_only=True,
  479. hide_rest=False,
  480. num_features=num_to_show)
  481. axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
  482. axes[ncols * 2 + i].set_title(
  483. "NormLIME: first {} superpixels".format(num_to_show))
  484. # NormLIME*LIME visualization
  485. combined_weights = combine_normlime_and_lime(lime_weights,
  486. g_weights)
  487. self._lime.lime_interpreter.local_weights = combined_weights
  488. for i, num_to_show in enumerate(nums_to_show):
  489. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  490. l,
  491. positive_only=True,
  492. hide_rest=False,
  493. num_features=num_to_show)
  494. axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
  495. axes[ncols * 3 + i].set_title(
  496. "Combined: first {} superpixels".format(num_to_show))
  497. self._lime.lime_interpreter.local_weights = lime_weights
  498. if save_outdir is not None:
  499. os.makedirs(save_outdir, exist_ok=True)
  500. save_fig(data_, save_outdir, 'normlime', self.num_samples)
  501. if visualization:
  502. plt.show()
  503. def auto_choose_num_features_to_show(lime_interpreter, label,
  504. percentage_to_show):
  505. segments = lime_interpreter.segments
  506. lime_weights = lime_interpreter.local_weights[label]
  507. num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[
  508. 1] // len(np.unique(segments)) // 8
  509. # l1 norm with filtered weights.
  510. used_weights = [(tuple_w[0], tuple_w[1])
  511. for i, tuple_w in enumerate(lime_weights)
  512. if tuple_w[1] > 0]
  513. norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
  514. normalized_weights = [(tuple_w[0], tuple_w[1] / norm)
  515. for i, tuple_w in enumerate(lime_weights)]
  516. a = 0.0
  517. n = 0
  518. for i, tuple_w in enumerate(normalized_weights):
  519. if tuple_w[1] < 0:
  520. continue
  521. if len(np.where(segments == tuple_w[0])[
  522. 0]) < num_pixels_threshold_in_a_sp:
  523. continue
  524. a += tuple_w[1]
  525. if a > percentage_to_show:
  526. n = i + 1
  527. break
  528. if percentage_to_show <= 0.0:
  529. return 5
  530. if n == 0:
  531. return auto_choose_num_features_to_show(lime_interpreter, label,
  532. percentage_to_show - 0.1)
  533. return n
  534. def get_cam(image_show,
  535. feature_maps,
  536. fc_weights,
  537. label_index,
  538. cam_min=None,
  539. cam_max=None):
  540. _, nc, h, w = feature_maps.shape
  541. cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
  542. cam = cam.sum((0, 1))
  543. if cam_min is None:
  544. cam_min = np.min(cam)
  545. if cam_max is None:
  546. cam_max = np.max(cam)
  547. cam = cam - cam_min
  548. cam = cam / cam_max
  549. cam = np.uint8(255 * cam)
  550. cam_img = cv2.resize(
  551. cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
  552. heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
  553. heatmap = np.float32(heatmap)
  554. cam = heatmap + np.float32(image_show)
  555. cam = cam / np.max(cam)
  556. return cam
  557. def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
  558. import matplotlib.pyplot as plt
  559. if isinstance(data_, str):
  560. if algorithm_name == 'cam':
  561. f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
  562. else:
  563. f_out = "{}_{}_s{}.png".format(algorithm_name,
  564. data_.split('/')[-1], num_samples)
  565. plt.savefig(os.path.join(save_outdir, f_out))
  566. else:
  567. n = 0
  568. if algorithm_name == 'cam':
  569. f_out = 'cam-{}.png'.format(n)
  570. else:
  571. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  572. while os.path.exists(os.path.join(save_outdir, f_out)):
  573. n += 1
  574. if algorithm_name == 'cam':
  575. f_out = 'cam-{}.png'.format(n)
  576. else:
  577. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  578. continue
  579. plt.savefig(os.path.join(save_outdir, f_out))
  580. logging.info('The image of intrepretation result save in {}'.format(
  581. os.path.join(save_outdir, f_out)))