# interpretation_algorithms.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

import cv2
import numpy as np

from . import lime_base
from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, h_pre_models_kmeans
from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
from paddlex.interpret.as_data_reader.readers import read_image


class CAM(object):
    def __init__(self, predict_fn, label_names):
        """
        Args:
            predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
                        output: [
                            logits [N, num_classes],
                            feature map before global average pooling [N, num_channels, h_, w_]
                        ]
        """
        self.predict_fn = predict_fn
        self.label_names = label_names

    def preparation_cam(self, data_):
        image_show = read_image(data_)
        result = self.predict_fn(image_show)

        logit = result[0][0]
        if abs(np.sum(logit) - 1.0) > 1e-4:
            # softmax
            logit = logit - np.max(logit)
            exp_result = np.exp(logit)
            probability = exp_result / np.sum(exp_result)
        else:
            probability = logit

        # only interpret top 1
        pred_label = np.argsort(probability)
        pred_label = pred_label[-1:]

        self.predicted_label = pred_label[0]
        self.predicted_probability = probability[pred_label[0]]
        self.image = image_show[0]
        self.labels = pred_label

        fc_weights = paddle_get_fc_weights()
        feature_maps = result[1]

        l = pred_label[0]
        ln = l
        if self.label_names is not None:
            ln = self.label_names[l]
        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
        return feature_maps, fc_weights

    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
        feature_maps, fc_weights = self.preparation_cam(data_)
        cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)

        if visualization or save_to_disk:
            import matplotlib.pyplot as plt
            from skimage.segmentation import mark_boundaries
            l = self.labels[0]
            ln = l
            if self.label_names is not None:
                ln = self.label_names[l]

            psize = 5
            nrows = 1
            ncols = 2

            plt.close()
            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
            for ax in axes.ravel():
                ax.axis("off")
            axes = axes.ravel()
            axes[0].imshow(self.image)
            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")

            axes[1].imshow(cam)
            axes[1].set_title("CAM")

            if save_to_disk and save_outdir is not None:
                os.makedirs(save_outdir, exist_ok=True)
                save_fig(data_, save_outdir, 'cam')

            if visualization:
                plt.show()

        return
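

# Illustrative sketch only (not part of the original module): a dummy
# `predict_fn` with the input/output contract that the CAM wrapper above
# documents, i.e. images [N, H, W, 3] in RGB 0-255 in, and a list of
# [logits [N, num_classes], feature maps [N, num_channels, h_, w_]] out.
# The function name and all shapes below are hypothetical stand-ins, not
# part of the PaddleX API.
def _example_cam_predict_fn(images_show):
    """Return random logits and feature maps with CAM-compatible shapes."""
    n = images_show.shape[0]
    logits = np.random.rand(n, 1000)              # stand-in classifier head output
    feature_maps = np.random.rand(n, 2048, 7, 7)  # stand-in last-conv features
    return [logits, feature_maps]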


class LIME(object):
    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
        """
        LIME wrapper. See lime_base.py for the detailed LIME implementation.
        Args:
            predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
            num_samples: the number of samples that LIME takes for fitting.
            batch_size: batch size for model inference each time.
        """
        self.num_samples = num_samples
        self.batch_size = batch_size

        self.predict_fn = predict_fn
        self.labels = None
        self.image = None
        self.lime_interpreter = None
        self.label_names = label_names

    def preparation_lime(self, data_):
        image_show = read_image(data_)
        result = self.predict_fn(image_show)

        result = result[0]  # only one image here.
        if abs(np.sum(result) - 1.0) > 1e-4:
            # softmax
            result = result - np.max(result)
            exp_result = np.exp(result)
            probability = exp_result / np.sum(exp_result)
        else:
            probability = result

        # only interpret top 1
        pred_label = np.argsort(probability)
        pred_label = pred_label[-1:]

        self.predicted_label = pred_label[0]
        self.predicted_probability = probability[pred_label[0]]
        self.image = image_show[0]
        self.labels = pred_label

        l = pred_label[0]
        ln = l
        if self.label_names is not None:
            ln = self.label_names[l]
        print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')

        end = time.time()
        algo = lime_base.LimeImageInterpreter()
        interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0,
                                              num_samples=self.num_samples, batch_size=self.batch_size)
        self.lime_interpreter = interpreter
        print('lime time: ', time.time() - end, 's.')

    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
        if self.lime_interpreter is None:
            self.preparation_lime(data_)

        if visualization or save_to_disk:
            import matplotlib.pyplot as plt
            from skimage.segmentation import mark_boundaries
            l = self.labels[0]
            ln = l
            if self.label_names is not None:
                ln = self.label_names[l]

            psize = 5
            nrows = 2
            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
            ncols = len(weights_choices)

            plt.close()
            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
            for ax in axes.ravel():
                ax.axis("off")
            axes = axes.ravel()
            axes[0].imshow(self.image)
            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")

            axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments))
            axes[1].set_title("superpixel segmentation")

            # LIME visualization
            for i, w in enumerate(weights_choices):
                num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w)
                temp, mask = self.lime_interpreter.get_image_and_mask(
                    l, positive_only=False, hide_rest=False, num_features=num_to_show
                )
                axes[ncols + i].imshow(mark_boundaries(temp, mask))
                axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")

            if save_to_disk and save_outdir is not None:
                os.makedirs(save_outdir, exist_ok=True)
                save_fig(data_, save_outdir, 'lime', self.num_samples)

            if visualization:
                plt.show()

        return
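

# Illustrative usage sketch only (the image path and output directory are
# hypothetical, and `predict_fn` is whatever callable the caller wires to the
# trained classifier):
#
#     lime = LIME(predict_fn, label_names, num_samples=1000, batch_size=50)
#     lime.interpret('test.jpg', visualization=False, save_to_disk=True,
#                    save_outdir='./lime_output')
#
# Per the docstring above, num_samples is the number of perturbed samples LIME
# fits its local surrogate on (more samples, slower but steadier explanations),
# while batch_size only controls how many of those samples are sent to
# predict_fn per inference call.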


class NormLIME(object):
    def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
                 kmeans_model_for_normlime=None, normlime_weights=None):
        if kmeans_model_for_normlime is None:
            try:
                self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
            except:
                raise ValueError("NormLIME needs the KMeans model; a default one is provided in "
                                 "pre_models/kmeans_model.pkl.")
        else:
            print("Warning: It is *strongly* suggested to use the default KMeans model in "
                  "pre_models/kmeans_model.pkl. Using another one will change the final result.")
            self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)

        self.num_samples = num_samples
        self.batch_size = batch_size

        try:
            self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
        except:
            self.normlime_weights = None
            print("Warning: cannot find the correct precomputed NormLIME result.")

        self.predict_fn = predict_fn

        self.labels = None
        self.image = None
        self.label_names = label_names

    def predict_cluster_labels(self, feature_map, segments):
        return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))

    def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
        # global weights
        g_weights = {y: [] for y in pred_labels}
        for y in pred_labels:
            cluster_weights_y = self.normlime_weights.get(y, {})
            g_weights[y] = [
                (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
            ]
            g_weights[y] = sorted(g_weights[y],
                                  key=lambda x: np.abs(x[1]), reverse=True)

        return g_weights

    def preparation_normlime(self, data_):
        self._lime = LIME(
            self.predict_fn,
            self.label_names,
            self.num_samples,
            self.batch_size
        )
        self._lime.preparation_lime(data_)

        image_show = read_image(data_)

        self.predicted_label = self._lime.predicted_label
        self.predicted_probability = self._lime.predicted_probability
        self.image = image_show[0]
        self.labels = self._lime.labels
        # print(f'predicted result: {self.predicted_label} with probability {self.predicted_probability: .3f}')
        print('performing NormLIME operations ...')

        cluster_labels = self.predict_cluster_labels(
            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments
        )

        g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)

        return g_weights

    def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
        if self.normlime_weights is None:
            raise ValueError("Cannot find the correct precomputed NormLIME result. \n"
                             "\t Try to call compute_normlime_weights() first or load the correct path.")

        g_weights = self.preparation_normlime(data_)
        lime_weights = self._lime.lime_interpreter.local_weights

        if visualization or save_to_disk:
            import matplotlib.pyplot as plt
            from skimage.segmentation import mark_boundaries
            l = self.labels[0]
            ln = l
            if self.label_names is not None:
                ln = self.label_names[l]

            psize = 5
            nrows = 4
            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
            nums_to_show = []
            ncols = len(weights_choices)

            plt.close()
            f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
            for ax in axes.ravel():
                ax.axis("off")
            axes = axes.ravel()
            axes[0].imshow(self.image)
            axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")

            axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments))
            axes[1].set_title("superpixel segmentation")

            # LIME visualization
            for i, w in enumerate(weights_choices):
                num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w)
                nums_to_show.append(num_to_show)
                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                    l, positive_only=False, hide_rest=False, num_features=num_to_show
                )
                axes[ncols + i].imshow(mark_boundaries(temp, mask))
                axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")

            # NormLIME visualization
            self._lime.lime_interpreter.local_weights = g_weights
            for i, num_to_show in enumerate(nums_to_show):
                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                    l, positive_only=False, hide_rest=False, num_features=num_to_show
                )
                axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
                axes[ncols * 2 + i].set_title(f"NormLIME: first {num_to_show} superpixels")

            # NormLIME*LIME visualization
            combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
            self._lime.lime_interpreter.local_weights = combined_weights
            for i, num_to_show in enumerate(nums_to_show):
                temp, mask = self._lime.lime_interpreter.get_image_and_mask(
                    l, positive_only=False, hide_rest=False, num_features=num_to_show
                )
                axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
                axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")

            # restore the original LIME weights on the shared interpreter
            self._lime.lime_interpreter.local_weights = lime_weights

            if save_to_disk and save_outdir is not None:
                os.makedirs(save_outdir, exist_ok=True)
                save_fig(data_, save_outdir, 'normlime', self.num_samples)

            if visualization:
                plt.show()
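

# Illustrative note only (all values are hypothetical): `predict_using_normlime_weights`
# above maps each predicted label to a list of (superpixel_index, global_weight)
# pairs, sorted by |weight|. For example, with precomputed NormLIME weights
# {283: {5: 0.4, 9: -0.1}} and superpixels assigned to clusters [9, 5, 7], the
# entry for label 283 would be [(1, 0.4), (0, -0.1), (2, 0.0)]: superpixel 1
# belongs to cluster 5, superpixel 0 to cluster 9, and cluster 7 has no
# precomputed weight, so it falls back to 0.0.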


def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show):
    segments = lime_interpreter.segments
    lime_weights = lime_interpreter.local_weights[label]
    num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8

    # l1 norm with filtered weights.
    used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
    norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
    normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]

    a = 0.0
    n = 0
    for i, tuple_w in enumerate(normalized_weights):
        if tuple_w[1] < 0:
            continue
        if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
            continue

        a += tuple_w[1]
        if a > percentage_to_show:
            n = i + 1
            break

    if percentage_to_show <= 0.0:
        return 5

    if n == 0:
        return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show - 0.1)

    return n
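

# A small self-contained check, illustrative only (the fake interpreter below is
# a hypothetical stand-in for lime_base.LimeImageInterpreter): with three
# superpixels of 100 pixels each and positive weights 0.5 / 0.3 / 0.2, the
# cumulative share crosses a 0.7 threshold at the second superpixel, so the
# function returns 2. If no prefix of the sorted weights crosses the threshold,
# the threshold is lowered by 0.1 and the search is retried.
#
#     class _FakeInterpreter:
#         segments = np.repeat(np.arange(3), 100).reshape(20, 15)  # 3 superpixels, 100 px each
#         local_weights = {0: [(0, 0.5), (1, 0.3), (2, 0.2)]}
#
#     auto_choose_num_features_to_show(_FakeInterpreter(), 0, 0.7)  # -> 2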


def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
    _, nc, h, w = feature_maps.shape

    cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
    cam = cam.sum((0, 1))

    if cam_min is None:
        cam_min = np.min(cam)
    if cam_max is None:
        cam_max = np.max(cam)

    cam = cam - cam_min
    cam = cam / cam_max
    cam = np.uint8(255 * cam)
    cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap)
    cam = heatmap + np.float32(image_show)
    cam = cam / np.max(cam)
    return cam
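

# Illustrative only (not part of the original module): a shape walk-through of
# get_cam with random stand-in tensors. The channel-weighted sum of the
# [1, C, h, w] feature maps collapses to an h x w map, which is rescaled,
# resized to the input image size, colour-mapped and blended with the image.
# All shapes and the label index below are hypothetical.
def _example_get_cam_shapes():
    feature_maps = np.random.rand(1, 2048, 7, 7)   # [N, C, h, w]
    fc_weights = np.random.rand(2048, 1000)        # [C, num_classes]
    image_show = np.random.randint(0, 255, (224, 224, 3)).astype('uint8')
    overlay = get_cam(image_show, feature_maps, fc_weights, label_index=283)
    return overlay.shape                           # (224, 224, 3)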


def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
    import matplotlib.pyplot as plt
    if isinstance(data_, str):
        if algorithm_name == 'cam':
            f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
        else:
            f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
        plt.savefig(os.path.join(save_outdir, f_out))
    else:
        n = 0
        if algorithm_name == 'cam':
            f_out = f'cam-{n}.png'
        else:
            f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
        while os.path.exists(os.path.join(save_outdir, f_out)):
            n += 1
            if algorithm_name == 'cam':
                f_out = f'cam-{n}.png'
            else:
                f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
            continue
        plt.savefig(os.path.join(save_outdir, f_out))

    print('The interpretation result image is saved at {}'.format(
        os.path.join(save_outdir, f_out)))
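

# Illustrative note on the naming scheme above (paths are hypothetical): for a
# string input such as 'images/cat.jpg', LIME results with the default 3000
# samples are written as 'lime_cat.jpg_s3000.png' inside save_outdir, and CAM
# results as 'cam_cat.jpg.png'; for array inputs, an increasing '-<n>' suffix
# ('cam-0.png', 'cam-1.png', ...) is used to avoid overwriting earlier figures.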