visualize.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import cv2
  16. import copy
  17. import os.path as osp
  18. import numpy as np
  19. import paddlex as pdx
  20. from .interpretation_predict import interpretation_predict
  21. from .core.interpretation import Interpretation
  22. from .core.normlime_base import precompute_global_classifier
  23. from .core._session_preparation import gen_user_home
  24. from paddlex.cv.transforms import arrange_transforms
  25. def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
  26. """使用LIME算法将模型预测结果的可解释性可视化。
  27. LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
  28. 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
  29. 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
  30. 得到每个输入维度的权重,以此来解释模型。
  31. 注意:LIME可解释性结果可视化目前只支持分类模型。
  32. Args:
  33. img_file (str): 预测图像路径。
  34. model (paddlex.cv.models): paddlex中的模型。
  35. num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
  36. batch_size (int): 预测数据batch大小,默认为50。
  37. save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
  38. """
  39. assert model.model_type == 'classifier', \
  40. 'Now the interpretation visualize only be supported in classifier!'
  41. if model.status != 'Normal':
  42. raise Exception(
  43. 'The interpretation only can deal with the Normal model')
  44. if not osp.exists(save_dir):
  45. os.makedirs(save_dir)
  46. arrange_transforms(
  47. model.model_type,
  48. model.__class__.__name__,
  49. transforms=model.test_transforms,
  50. mode='test')
  51. tmp_transforms = copy.deepcopy(model.test_transforms)
  52. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  53. img = tmp_transforms(img_file)[0]
  54. img = np.around(img).astype('uint8')
  55. img = np.expand_dims(img, axis=0)
  56. interpreter = None
  57. interpreter = get_lime_interpreter(
  58. img, model, num_samples=num_samples, batch_size=batch_size)
  59. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  60. interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
  61. def normlime(img_file,
  62. model,
  63. dataset=None,
  64. num_samples=3000,
  65. batch_size=50,
  66. save_dir='./',
  67. normlime_weights_file=None):
  68. """使用NormLIME算法将模型预测结果的可解释性可视化。
  69. NormLIME是利用一定数量的样本来出一个全局的解释。由于NormLIME计算量较大,此处采用一种简化的方式:
  70. 使用一定数量的测试样本(目前默认使用所有测试样本),对每个样本进行特征提取,映射到同一个特征空间;
  71. 然后以此特征做为输入,以模型输出做为输出,使用线性回归对其进行拟合,得到一个全局的输入和输出的关系。
  72. 之后,对一测试样本进行解释时,使用NormLIME全局的解释,来对LIME的结果进行滤波,使最终的可视化结果更加稳定。
  73. 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
  74. 注意2:NormLIME可解释性结果可视化目前只支持分类模型。
  75. Args:
  76. img_file (str): 预测图像路径。
  77. model (paddlex.cv.models): paddlex中的模型。
  78. dataset (paddlex.datasets): 数据集读取器,默认为None。
  79. num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
  80. batch_size (int): 预测数据batch大小,默认为50。
  81. save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
  82. normlime_weights_file (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。
  83. """
  84. assert model.model_type == 'classifier', \
  85. 'Now the interpretation visualize only be supported in classifier!'
  86. if model.status != 'Normal':
  87. raise Exception(
  88. 'The interpretation only can deal with the Normal model')
  89. if not osp.exists(save_dir):
  90. os.makedirs(save_dir)
  91. arrange_transforms(
  92. model.model_type,
  93. model.__class__.__name__,
  94. transforms=model.test_transforms,
  95. mode='test')
  96. tmp_transforms = copy.deepcopy(model.test_transforms)
  97. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  98. img = tmp_transforms(img_file)[0]
  99. img = np.around(img).astype('uint8')
  100. img = np.expand_dims(img, axis=0)
  101. interpreter = None
  102. if dataset is None:
  103. raise Exception(
  104. 'The dataset is None. Cannot implement this kind of interpretation')
  105. interpreter = get_normlime_interpreter(
  106. img,
  107. model,
  108. dataset,
  109. num_samples=num_samples,
  110. batch_size=batch_size,
  111. save_dir=save_dir,
  112. normlime_weights_file=normlime_weights_file)
  113. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  114. interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
  115. def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
  116. def predict_func(image):
  117. out = interpretation_predict(model, image)
  118. return out[0]
  119. labels_name = None
  120. if hasattr(model, 'labels'):
  121. labels_name = model.labels
  122. interpreter = Interpretation(
  123. 'lime',
  124. predict_func,
  125. labels_name,
  126. num_samples=num_samples,
  127. batch_size=batch_size)
  128. return interpreter
  129. def get_normlime_interpreter(img,
  130. model,
  131. dataset,
  132. num_samples=3000,
  133. batch_size=50,
  134. save_dir='./',
  135. normlime_weights_file=None):
  136. def predict_func(image):
  137. out = interpretation_predict(model, image)
  138. return out[0]
  139. labels_name = None
  140. if dataset is not None:
  141. labels_name = dataset.labels
  142. root_path = gen_user_home()
  143. root_path = osp.join(root_path, '.paddlex')
  144. pre_models_path = osp.join(root_path, "pre_models")
  145. if not osp.exists(pre_models_path):
  146. if not osp.exists(root_path):
  147. os.makedirs(root_path)
  148. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  149. pdx.utils.download_and_decompress(url, path=root_path)
  150. if osp.exists(osp.join(save_dir, normlime_weights_file)):
  151. normlime_weights_file = osp.join(save_dir, normlime_weights_file)
  152. try:
  153. np.load(normlime_weights_file, allow_pickle=True).item()
  154. except:
  155. normlime_weights_file = precompute_global_classifier(
  156. dataset,
  157. predict_func,
  158. save_path=normlime_weights_file,
  159. batch_size=batch_size)
  160. else:
  161. normlime_weights_file = precompute_global_classifier(
  162. dataset,
  163. predict_func,
  164. save_path=osp.join(save_dir, normlime_weights_file),
  165. batch_size=batch_size)
  166. interpreter = Interpretation(
  167. 'normlime',
  168. predict_func,
  169. labels_name,
  170. num_samples=num_samples,
  171. batch_size=batch_size,
  172. normlime_weights=normlime_weights_file)
  173. return interpreter