interpretation_algorithms.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. import time
  18. from . import lime_base
  19. from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, gen_user_home
  20. from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
  21. from paddlex.interpret.as_data_reader.readers import read_image
  22. import paddlex.utils.logging as logging
  23. import cv2
  24. class CAM(object):
  25. def __init__(self, predict_fn, label_names):
  26. """
  27. Args:
  28. predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
  29. output: [
  30. logits [N, num_classes],
  31. feature map before global average pooling [N, num_channels, h_, w_]
  32. ]
  33. """
  34. self.predict_fn = predict_fn
  35. self.label_names = label_names
  36. def preparation_cam(self, data_):
  37. image_show = read_image(data_)
  38. result = self.predict_fn(image_show)
  39. logit = result[0][0]
  40. if abs(np.sum(logit) - 1.0) > 1e-4:
  41. # softmax
  42. logit = logit - np.max(logit)
  43. exp_result = np.exp(logit)
  44. probability = exp_result / np.sum(exp_result)
  45. else:
  46. probability = logit
  47. # only interpret top 1
  48. pred_label = np.argsort(probability)
  49. pred_label = pred_label[-1:]
  50. self.predicted_label = pred_label[0]
  51. self.predicted_probability = probability[pred_label[0]]
  52. self.image = image_show[0]
  53. self.labels = pred_label
  54. fc_weights = paddle_get_fc_weights()
  55. feature_maps = result[1]
  56. l = pred_label[0]
  57. ln = l
  58. if self.label_names is not None:
  59. ln = self.label_names[l]
  60. prob_str = "%.3f" % (probability[pred_label[0]])
  61. logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
  62. return feature_maps, fc_weights
  63. def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
  64. feature_maps, fc_weights = self.preparation_cam(data_)
  65. cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
  66. if visualization or save_to_disk:
  67. import matplotlib.pyplot as plt
  68. from skimage.segmentation import mark_boundaries
  69. l = self.labels[0]
  70. ln = l
  71. if self.label_names is not None:
  72. ln = self.label_names[l]
  73. psize = 5
  74. nrows = 1
  75. ncols = 2
  76. plt.close()
  77. f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
  78. for ax in axes.ravel():
  79. ax.axis("off")
  80. axes = axes.ravel()
  81. axes[0].imshow(self.image)
  82. prob_str = "{%.3f}" % (self.predicted_probability)
  83. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  84. axes[1].imshow(cam)
  85. axes[1].set_title("CAM")
  86. if save_to_disk and save_outdir is not None:
  87. os.makedirs(save_outdir, exist_ok=True)
  88. save_fig(data_, save_outdir, 'cam')
  89. if visualization:
  90. plt.show()
  91. return
  92. class LIME(object):
  93. def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
  94. """
  95. LIME wrapper. See lime_base.py for the detailed LIME implementation.
  96. Args:
  97. predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
  98. num_samples: the number of samples that LIME takes for fitting.
  99. batch_size: batch size for model inference each time.
  100. """
  101. self.num_samples = num_samples
  102. self.batch_size = batch_size
  103. self.predict_fn = predict_fn
  104. self.labels = None
  105. self.image = None
  106. self.lime_interpreter = None
  107. self.label_names = label_names
  108. def preparation_lime(self, data_):
  109. image_show = read_image(data_)
  110. result = self.predict_fn(image_show)
  111. result = result[0] # only one image here.
  112. if abs(np.sum(result) - 1.0) > 1e-4:
  113. # softmax
  114. result = result - np.max(result)
  115. exp_result = np.exp(result)
  116. probability = exp_result / np.sum(exp_result)
  117. else:
  118. probability = result
  119. # only interpret top 1
  120. pred_label = np.argsort(probability)
  121. pred_label = pred_label[-1:]
  122. self.predicted_label = pred_label[0]
  123. self.predicted_probability = probability[pred_label[0]]
  124. self.image = image_show[0]
  125. self.labels = pred_label
  126. l = pred_label[0]
  127. ln = l
  128. if self.label_names is not None:
  129. ln = self.label_names[l]
  130. prob_str = "%.3f" % (probability[pred_label[0]])
  131. logging.info("predicted result: {} with probability {}.".format(ln, prob_str))
  132. end = time.time()
  133. algo = lime_base.LimeImageInterpreter()
  134. interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
  135. num_samples=self.num_samples, batch_size=self.batch_size)
  136. self.lime_interpreter = interpreter
  137. logging.info('lime time: ' + str(time.time() - end) + 's.')
  138. def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
  139. if self.lime_interpreter is None:
  140. self.preparation_lime(data_)
  141. if visualization or save_to_disk:
  142. import matplotlib.pyplot as plt
  143. from skimage.segmentation import mark_boundaries
  144. l = self.labels[0]
  145. ln = l
  146. if self.label_names is not None:
  147. ln = self.label_names[l]
  148. psize = 5
  149. nrows = 2
  150. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  151. ncols = len(weights_choices)
  152. plt.close()
  153. f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
  154. for ax in axes.ravel():
  155. ax.axis("off")
  156. axes = axes.ravel()
  157. axes[0].imshow(self.image)
  158. prob_str = "{%.3f}" % (self.predicted_probability)
  159. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  160. axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
  161. axes[1].set_title("superpixel segmentation")
  162. # LIME visualization
  163. for i, w in enumerate(weights_choices):
  164. num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
  165. temp, mask = self.lime_interpreter.get_image_and_mask(
  166. l, positive_only=False, hide_rest=False, num_features=num_to_show
  167. )
  168. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  169. axes[ncols + i].set_title("label {}, first {} superpixels".format(ln, num_to_show))
  170. if save_to_disk and save_outdir is not None:
  171. os.makedirs(save_outdir, exist_ok=True)
  172. save_fig(data_, save_outdir, 'lime', self.num_samples)
  173. if visualization:
  174. plt.show()
  175. return
  176. class NormLIME(object):
  177. def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
  178. kmeans_model_for_normlime=None, normlime_weights=None):
  179. root_path = gen_user_home()
  180. root_path = osp.join(root_path, '.paddlex')
  181. h_pre_models = osp.join(root_path, "pre_models")
  182. if not osp.exists(h_pre_models):
  183. if not osp.exists(root_path):
  184. os.makedirs(root_path)
  185. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  186. pdx.utils.download_and_decompress(url, path=root_path)
  187. h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
  188. if kmeans_model_for_normlime is None:
  189. try:
  190. self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
  191. except:
  192. raise ValueError("NormLIME needs the KMeans model, where we provided a default one in "
  193. "pre_models/kmeans_model.pkl.")
  194. else:
  195. logging.debug("Warning: It is *strongly* suggested to use the \
  196. default KMeans model in pre_models/kmeans_model.pkl. \
  197. Use another one will change the final result.")
  198. self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
  199. self.num_samples = num_samples
  200. self.batch_size = batch_size
  201. try:
  202. self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
  203. except:
  204. self.normlime_weights = None
  205. logging.debug("Warning: not find the correct precomputed Normlime result.")
  206. self.predict_fn = predict_fn
  207. self.labels = None
  208. self.image = None
  209. self.label_names = label_names
  210. def predict_cluster_labels(self, feature_map, segments):
  211. X = get_feature_for_kmeans(feature_map, segments)
  212. try:
  213. cluster_labels = self.kmeans_model.predict(X)
  214. except AttributeError:
  215. from sklearn.metrics import pairwise_distances_argmin_min
  216. cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_)
  217. return cluster_labels
  218. def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
  219. # global weights
  220. g_weights = {y: [] for y in pred_labels}
  221. for y in pred_labels:
  222. cluster_weights_y = self.normlime_weights.get(y, {})
  223. g_weights[y] = [
  224. (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
  225. ]
  226. g_weights[y] = sorted(g_weights[y],
  227. key=lambda x: np.abs(x[1]), reverse=True)
  228. return g_weights
  229. def preparation_normlime(self, data_):
  230. self._lime = LIME(
  231. self.predict_fn,
  232. self.label_names,
  233. self.num_samples,
  234. self.batch_size
  235. )
  236. self._lime.preparation_lime(data_)
  237. image_show = read_image(data_)
  238. self.predicted_label = self._lime.predicted_label
  239. self.predicted_probability = self._lime.predicted_probability
  240. self.image = image_show[0]
  241. self.labels = self._lime.labels
  242. logging.info('performing NormLIME operations ...')
  243. cluster_labels = self.predict_cluster_labels(
  244. compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
  245. )
  246. g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
  247. return g_weights
  248. def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
  249. if self.normlime_weights is None:
  250. raise ValueError("Not find the correct precomputed NormLIME result. \n"
  251. "\t Try to call compute_normlime_weights() first or load the correct path.")
  252. g_weights = self.preparation_normlime(data_)
  253. lime_weights = self._lime.lime_interpreter.local_weights
  254. if visualization or save_to_disk:
  255. import matplotlib.pyplot as plt
  256. from skimage.segmentation import mark_boundaries
  257. l = self.labels[0]
  258. ln = l
  259. if self.label_names is not None:
  260. ln = self.label_names[l]
  261. psize = 5
  262. nrows = 4
  263. weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
  264. nums_to_show = []
  265. ncols = len(weights_choices)
  266. plt.close()
  267. f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
  268. for ax in axes.ravel():
  269. ax.axis("off")
  270. axes = axes.ravel()
  271. axes[0].imshow(self.image)
  272. prob_str = "{%.3f}" % (self.predicted_probability)
  273. axes[0].set_title("label {}, proba: {}".format(ln, prob_str))
  274. axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
  275. axes[1].set_title("superpixel segmentation")
  276. # LIME visualization
  277. for i, w in enumerate(weights_choices):
  278. num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
  279. nums_to_show.append(num_to_show)
  280. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  281. l, positive_only=False, hide_rest=False, num_features=num_to_show
  282. )
  283. axes[ncols + i].imshow(mark_boundaries(temp, mask))
  284. axes[ncols + i].set_title("LIME: first {} superpixels".format(num_to_show))
  285. # NormLIME visualization
  286. self._lime.lime_interpreter.local_weights = g_weights
  287. for i, num_to_show in enumerate(nums_to_show):
  288. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  289. l, positive_only=False, hide_rest=False, num_features=num_to_show
  290. )
  291. axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
  292. axes[ncols * 2 + i].set_title("NormLIME: first {} superpixels".format(num_to_show))
  293. # NormLIME*LIME visualization
  294. combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
  295. self._lime.lime_interpreter.local_weights = combined_weights
  296. for i, num_to_show in enumerate(nums_to_show):
  297. temp, mask = self._lime.lime_interpreter.get_image_and_mask(
  298. l, positive_only=False, hide_rest=False, num_features=num_to_show
  299. )
  300. axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
  301. axes[ncols * 3 + i].set_title("Combined: first {} superpixels".format(num_to_show))
  302. self._lime.lime_interpreter.local_weights = lime_weights
  303. if save_to_disk and save_outdir is not None:
  304. os.makedirs(save_outdir, exist_ok=True)
  305. save_fig(data_, save_outdir, 'normlime', self.num_samples)
  306. if visualization:
  307. plt.show()
  308. def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
  309. segments = lime_interpreter.segments
  310. lime_weights = lime_interpreter.local_weights[label]
  311. num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
  312. # l1 norm with filtered weights.
  313. used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
  314. norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
  315. normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
  316. a = 0.0
  317. n = 0
  318. for i, tuple_w in enumerate(normalized_weights):
  319. if tuple_w[1] < 0:
  320. continue
  321. if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
  322. continue
  323. a += tuple_w[1]
  324. if a > percentage_to_show:
  325. n = i + 1
  326. break
  327. if percentage_to_show <= 0.0:
  328. return 5
  329. if n == 0:
  330. return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1)
  331. return n
  332. def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
  333. _, nc, h, w = feature_maps.shape
  334. cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
  335. cam = cam.sum((0, 1))
  336. if cam_min is None:
  337. cam_min = np.min(cam)
  338. if cam_max is None:
  339. cam_max = np.max(cam)
  340. cam = cam - cam_min
  341. cam = cam / cam_max
  342. cam = np.uint8(255 * cam)
  343. cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
  344. heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
  345. heatmap = np.float32(heatmap)
  346. cam = heatmap + np.float32(image_show)
  347. cam = cam / np.max(cam)
  348. return cam
  349. def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
  350. import matplotlib.pyplot as plt
  351. if isinstance(data_, str):
  352. if algorithm_name == 'cam':
  353. f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
  354. else:
  355. f_out = "{}_{}_s{}.png".format(algorithm_name, data_.split('/')[-1], num_samples)
  356. plt.savefig(
  357. os.path.join(save_outdir, f_out)
  358. )
  359. else:
  360. n = 0
  361. if algorithm_name == 'cam':
  362. f_out = 'cam-{}.png'.format(n)
  363. else:
  364. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  365. while os.path.exists(
  366. os.path.join(save_outdir, f_out)
  367. ):
  368. n += 1
  369. if algorithm_name == 'cam':
  370. f_out = 'cam-{}.png'.format(n)
  371. else:
  372. f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
  373. continue
  374. plt.savefig(
  375. os.path.join(
  376. save_outdir, f_out
  377. )
  378. )
  379. logging.info('The image of intrepretation result save in {}'.format(os.path.join(
  380. save_outdir, f_out
  381. )))