visualize.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  4. import os
  5. import cv2
  6. import copy
  7. import os.path as osp
  8. import numpy as np
  9. from .core.explanation import Explanation
  10. def visualize(img_file,
  11. model,
  12. explanation_type='lime',
  13. num_samples=3000,
  14. batch_size=50,
  15. save_dir='./'):
  16. model.arrange_transforms(
  17. transforms=model.test_transforms, mode='test')
  18. tmp_transforms = copy.deepcopy(model.test_transforms)
  19. tmp_transforms.transforms = tmp_transforms.transforms[:-2]
  20. img = tmp_transforms(img_file)[0]
  21. img = np.around(img).astype('uint8')
  22. img = np.expand_dims(img, axis=0)
  23. explaier = None
  24. if explanation_type == 'lime':
  25. explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
  26. else:
  27. raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
  28. img_name = osp.splitext(osp.split(img_file)[-1])[0]
  29. explaier.explain(img, save_dir=save_dir)
  30. def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
  31. def predict_func(image):
  32. image = image.astype('float32')
  33. model.test_transforms.transforms = model.test_transforms.transforms[-2:]
  34. out = model.explanation_predict(image)
  35. return out[0]
  36. explaier = Explanation('lime',
  37. predict_func,
  38. num_samples=num_samples,
  39. batch_size=batch_size)
  40. return explaier