explanation.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from .explanation_algorithms import CAM, LIME, NormLIME
  2. class Explanation(object):
  3. """
  4. Base class for all explanation algorithms.
  5. """
  6. def __init__(self, explanation_algorithm_name, predict_fn, **kwargs):
  7. supported_algorithms = {
  8. 'cam': CAM,
  9. 'lime': LIME,
  10. 'normlime': NormLIME
  11. }
  12. self.algorithm_name = explanation_algorithm_name.lower()
  13. assert self.algorithm_name in supported_algorithms.keys()
  14. self.predict_fn = predict_fn
  15. # initialization for the explanation algorithm.
  16. self.explain_algorithm = supported_algorithms[self.algorithm_name](
  17. self.predict_fn, **kwargs
  18. )
  19. def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
  20. """
  21. Args:
  22. data_: data_ can be a path or numpy.ndarray.
  23. visualization: whether to show using matplotlib.
  24. save_to_disk: whether to save the figure in local disk.
  25. save_dir: dir to save figure if save_to_disk is True.
  26. Returns:
  27. """
  28. return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)