visualize.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import copy
  17. import os.path as osp
  18. import numpy as np
  19. import paddlex as pdx
  20. from .interpretation_predict import interpretation_predict
  21. from .core.interpretation import Interpretation
  22. from .core.normlime_base import precompute_normlime_weights
  23. from .core._session_preparation import gen_user_home
  24. def lime(img_file,
  25. model,
  26. num_samples=3000,
  27. batch_size=50,
  28. save_dir='./'):
  29. """使用LIME算法将模型预测结果的可解释性可视化。
  30. LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
  31. 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
  32. 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
  33. 得到每个输入维度的权重,以此来解释模型。
  34. 注意:LIME可解释性结果可视化目前只支持分类模型。
  35. Args:
  36. img_file (str): 预测图像路径。
  37. model (paddlex.cv.models): paddlex中的模型。
  38. num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
  39. batch_size (int): 预测数据batch大小,默认为50。
  40. save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
  41. """
  42. assert model.model_type == 'classifier', \
  43. 'Now the interpretation visualize only be supported in classifier!'
  44. if model.status != 'Normal':
  45. raise Exception('The interpretation only can deal with the Normal model')
  46. if not osp.exists(save_dir):
  47. os.makedirs(save_dir)
  48. model.arrange_transforms(
  49. transforms=model.test_transforms, mode='test')
  50. tmp_transforms = copy.deepcopy(model.test_transforms)
  51. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  52. img = tmp_transforms(img_file)[0]
  53. img = np.around(img).astype('uint8')
  54. img = np.expand_dims(img, axis=0)
  55. interpreter = None
  56. interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
  57. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  58. interpreter.interpret(img, save_dir=save_dir)
  59. def normlime(img_file,
  60. model,
  61. dataset=None,
  62. num_samples=3000,
  63. batch_size=50,
  64. save_dir='./'):
  65. """使用NormLIME算法将模型预测结果的可解释性可视化。
  66. NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
  67. 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
  68. 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
  69. 注意2:NormLIME可解释性结果可视化目前只支持分类模型。
  70. Args:
  71. img_file (str): 预测图像路径。
  72. model (paddlex.cv.models): paddlex中的模型。
  73. dataset (paddlex.datasets): 数据集读取器,默认为None。
  74. num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
  75. batch_size (int): 预测数据batch大小,默认为50。
  76. save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
  77. """
  78. assert model.model_type == 'classifier', \
  79. 'Now the interpretation visualize only be supported in classifier!'
  80. if model.status != 'Normal':
  81. raise Exception('The interpretation only can deal with the Normal model')
  82. if not osp.exists(save_dir):
  83. os.makedirs(save_dir)
  84. model.arrange_transforms(
  85. transforms=model.test_transforms, mode='test')
  86. tmp_transforms = copy.deepcopy(model.test_transforms)
  87. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  88. img = tmp_transforms(img_file)[0]
  89. img = np.around(img).astype('uint8')
  90. img = np.expand_dims(img, axis=0)
  91. interpreter = None
  92. if dataset is None:
  93. raise Exception('The dataset is None. Cannot implement this kind of interpretation')
  94. interpreter = get_normlime_interpreter(img, model, dataset,
  95. num_samples=num_samples, batch_size=batch_size,
  96. save_dir=save_dir)
  97. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  98. interpreter.interpret(img, save_dir=save_dir)
  99. def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
  100. def predict_func(image):
  101. image = image.astype('float32')
  102. for i in range(image.shape[0]):
  103. image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
  104. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  105. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  106. out = interpretation_predict(model, image)
  107. model.test_transforms.transforms = tmp_transforms
  108. return out[0]
  109. labels_name = None
  110. if hasattr(model, 'labels'):
  111. labels_name = model.labels
  112. interpreter = Interpretation('lime',
  113. predict_func,
  114. labels_name,
  115. num_samples=num_samples,
  116. batch_size=batch_size)
  117. return interpreter
  118. def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
  119. def precompute_predict_func(image):
  120. image = image.astype('float32')
  121. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  122. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  123. out = interpretation_predict(model, image)
  124. model.test_transforms.transforms = tmp_transforms
  125. return out[0]
  126. def predict_func(image):
  127. image = image.astype('float32')
  128. for i in range(image.shape[0]):
  129. image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
  130. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  131. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  132. out = interpretation_predict(model, image)
  133. model.test_transforms.transforms = tmp_transforms
  134. return out[0]
  135. labels_name = None
  136. if dataset is not None:
  137. labels_name = dataset.labels
  138. root_path = gen_user_home()
  139. root_path = osp.join(root_path, '.paddlex')
  140. pre_models_path = osp.join(root_path, "pre_models")
  141. if not osp.exists(pre_models_path):
  142. if not osp.exists(root_path):
  143. os.makedirs(root_path)
  144. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  145. pdx.utils.download_and_decompress(url, path=root_path)
  146. npy_dir = precompute_for_normlime(precompute_predict_func,
  147. dataset,
  148. num_samples=num_samples,
  149. batch_size=batch_size,
  150. save_dir=save_dir)
  151. interpreter = Interpretation('normlime',
  152. predict_func,
  153. labels_name,
  154. num_samples=num_samples,
  155. batch_size=batch_size,
  156. normlime_weights=npy_dir)
  157. return interpreter
  158. def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
  159. image_list = []
  160. for item in dataset.file_list:
  161. image_list.append(item[0])
  162. return precompute_normlime_weights(
  163. image_list,
  164. predict_func,
  165. num_samples=num_samples,
  166. batch_size=batch_size,
  167. save_dir=save_dir)