visualize.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import copy
  17. import os.path as osp
  18. import numpy as np
  19. import paddlex as pdx
  20. from .interpretation_predict import interpretation_predict
  21. from .core.interpretation import Interpretation
  22. from .core.normlime_base import precompute_normlime_weights
  23. from .core._session_preparation import gen_user_home
  24. def visualize(img_file,
  25. model,
  26. dataset=None,
  27. algo='lime',
  28. num_samples=3000,
  29. batch_size=50,
  30. save_dir='./'):
  31. """可解释性可视化。
  32. 将模型预测结果的可解释性可视化,支持LIME和NormLIME两种可解释性算法。
  33. LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
  34. 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
  35. 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
  36. 得到每个输入维度的权重,以此来解释模型。
  37. NormLIME则是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
  38. 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
  39. 注意:dataset参数只有在algo为"normlime"的情况下才使用,dataset读取的是一个数据集,
  40. 该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
  41. Args:
  42. img_file (str): 预测图像路径。
  43. model (paddlex.cv.models): paddlex中的模型。
  44. dataset (paddlex.datasets): 数据集读取器,默认为None。
  45. algo (str): 可解释性方式,当前可选'lime'和'normlime'。
  46. num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
  47. batch_size (int): 预测数据batch大小,默认为50。
  48. save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
  49. """
  50. assert model.model_type == 'classifier', \
  51. 'Now the interpretation visualize only be supported in classifier!'
  52. if model.status != 'Normal':
  53. raise Exception('The interpretation only can deal with the Normal model')
  54. if not osp.exists(save_dir):
  55. os.makedirs(save_dir)
  56. model.arrange_transforms(
  57. transforms=model.test_transforms, mode='test')
  58. tmp_transforms = copy.deepcopy(model.test_transforms)
  59. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  60. img = tmp_transforms(img_file)[0]
  61. img = np.around(img).astype('uint8')
  62. img = np.expand_dims(img, axis=0)
  63. interpreter = None
  64. if algo == 'lime':
  65. interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
  66. elif algo == 'normlime':
  67. if dataset is None:
  68. raise Exception('The dataset is None. Cannot implement this kind of interpretation')
  69. interpreter = get_normlime_interpreter(img, model, dataset,
  70. num_samples=num_samples, batch_size=batch_size,
  71. save_dir=save_dir)
  72. else:
  73. raise Exception('The {} interpretation method is not supported yet!'.format(algo))
  74. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  75. interpreter.interpret(img, save_dir=save_dir)
  76. def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
  77. def predict_func(image):
  78. image = image.astype('float32')
  79. for i in range(image.shape[0]):
  80. image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
  81. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  82. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  83. out = interpretation_predict(model, image)
  84. model.test_transforms.transforms = tmp_transforms
  85. return out[0]
  86. labels_name = None
  87. if dataset is not None:
  88. labels_name = dataset.labels
  89. interpreter = Interpretation('lime',
  90. predict_func,
  91. labels_name,
  92. num_samples=num_samples,
  93. batch_size=batch_size)
  94. return interpreter
  95. def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
  96. def precompute_predict_func(image):
  97. image = image.astype('float32')
  98. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  99. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  100. out = interpretation_predict(model, image)
  101. model.test_transforms.transforms = tmp_transforms
  102. return out[0]
  103. def predict_func(image):
  104. image = image.astype('float32')
  105. for i in range(image.shape[0]):
  106. image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
  107. tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
  108. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  109. out = interpretation_predict(model, image)
  110. model.test_transforms.transforms = tmp_transforms
  111. return out[0]
  112. labels_name = None
  113. if dataset is not None:
  114. labels_name = dataset.labels
  115. root_path = gen_user_home()
  116. root_path = osp.join(root_path, '.paddlex')
  117. pre_models_path = osp.join(root_path, "pre_models")
  118. if not osp.exists(pre_models_path):
  119. if not osp.exists(root_path):
  120. os.makedirs(root_path)
  121. url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
  122. pdx.utils.download_and_decompress(url, path=root_path)
  123. npy_dir = precompute_for_normlime(precompute_predict_func,
  124. dataset,
  125. num_samples=num_samples,
  126. batch_size=batch_size,
  127. save_dir=save_dir)
  128. interpreter = Interpretation('normlime',
  129. predict_func,
  130. labels_name,
  131. num_samples=num_samples,
  132. batch_size=batch_size,
  133. normlime_weights=npy_dir)
  134. return interpreter
  135. def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
  136. image_list = []
  137. for item in dataset.file_list:
  138. image_list.append(item[0])
  139. return precompute_normlime_weights(
  140. image_list,
  141. predict_func,
  142. num_samples=num_samples,
  143. batch_size=batch_size,
  144. save_dir=save_dir)