  1. """
  2. Contains abstract functionality for learning locally linear sparse model.
  3. """
  4. from __future__ import print_function
  5. import numpy as np
  6. import scipy as sp
  7. import sklearn
  8. import sklearn.preprocessing
  9. from skimage.color import gray2rgb
  10. from sklearn.linear_model import Ridge, lars_path
  11. from sklearn.utils import check_random_state
  12. import copy
  13. from functools import partial
  14. from skimage.segmentation import quickshift
  15. from skimage.measure import regionprops


class LimeBase(object):
    """Class for learning a locally linear sparse model from perturbed data"""

    def __init__(self,
                 kernel_fn,
                 verbose=False,
                 random_state=None):
        """Init function

        Args:
            kernel_fn: function that transforms an array of distances into an
                array of proximity values (floats).
            verbose: if true, print local prediction values from linear model.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        self.kernel_fn = kernel_fn
        self.verbose = verbose
        self.random_state = check_random_state(random_state)
    @staticmethod
    def generate_lars_path(weighted_data, weighted_labels):
        """Generates the lars path for weighted data.

        Args:
            weighted_data: data that has been weighted by kernel
            weighted_labels: labels, weighted by kernel

        Returns:
            (alphas, coefs), both are arrays corresponding to the
            regularization parameter and coefficients, respectively
        """
        x_vector = weighted_data
        alphas, _, coefs = lars_path(x_vector,
                                     weighted_labels,
                                     method='lasso',
                                     verbose=False)
        return alphas, coefs
    def forward_selection(self, data, labels, weights, num_features):
        """Iteratively adds features to the model"""
        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
        used_features = []
        for _ in range(min(num_features, data.shape[1])):
            max_ = -np.inf
            best = 0
            for feature in range(data.shape[1]):
                if feature in used_features:
                    continue
                clf.fit(data[:, used_features + [feature]], labels,
                        sample_weight=weights)
                score = clf.score(data[:, used_features + [feature]],
                                  labels,
                                  sample_weight=weights)
                if score > max_:
                    best = feature
                    max_ = score
            used_features.append(best)
        return np.array(used_features)
    def feature_selection(self, data, labels, weights, num_features, method):
        """Selects features for the model. see explain_instance_with_data to
        understand the parameters."""
        if method == 'none':
            return np.array(range(data.shape[1]))
        elif method == 'forward_selection':
            return self.forward_selection(data, labels, weights, num_features)
        elif method == 'highest_weights':
            clf = Ridge(alpha=0.01, fit_intercept=True,
                        random_state=self.random_state)
            clf.fit(data, labels, sample_weight=weights)
            coef = clf.coef_
            if sp.sparse.issparse(data):
                coef = sp.sparse.csr_matrix(clf.coef_)
                weighted_data = coef.multiply(data[0])
                # Note: most efficient to slice the data before reversing
                sdata = len(weighted_data.data)
                argsort_data = np.abs(weighted_data.data).argsort()
                # Edge case where the data is more sparse than the requested
                # number of feature importances. In that case, pad with
                # zero-valued features.
                if sdata < num_features:
                    nnz_indexes = argsort_data[::-1]
                    indices = weighted_data.indices[nnz_indexes]
                    num_to_pad = num_features - sdata
                    indices = np.concatenate(
                        (indices, np.zeros(num_to_pad, dtype=indices.dtype)))
                    indices_set = set(indices)
                    pad_counter = 0
                    for i in range(data.shape[1]):
                        if i not in indices_set:
                            indices[pad_counter + sdata] = i
                            pad_counter += 1
                            if pad_counter >= num_to_pad:
                                break
                else:
                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
                    indices = weighted_data.indices[nnz_indexes]
                return indices
            else:
                weighted_data = coef * data[0]
                feature_weights = sorted(
                    zip(range(data.shape[1]), weighted_data),
                    key=lambda x: np.abs(x[1]),
                    reverse=True)
                return np.array([x[0] for x in feature_weights[:num_features]])
        elif method == 'lasso_path':
            weighted_data = ((data - np.average(data, axis=0, weights=weights))
                             * np.sqrt(weights[:, np.newaxis]))
            weighted_labels = ((labels - np.average(labels, weights=weights))
                               * np.sqrt(weights))
            nonzero = range(weighted_data.shape[1])
            _, coefs = self.generate_lars_path(weighted_data,
                                               weighted_labels)
            for i in range(len(coefs.T) - 1, 0, -1):
                nonzero = coefs.T[i].nonzero()[0]
                if len(nonzero) <= num_features:
                    break
            used_features = nonzero
            return used_features
        elif method == 'auto':
            if num_features <= 6:
                n_method = 'forward_selection'
            else:
                n_method = 'highest_weights'
            return self.feature_selection(data, labels, weights,
                                          num_features, n_method)
    def explain_instance_with_data(self,
                                   neighborhood_data,
                                   neighborhood_labels,
                                   distances,
                                   label,
                                   num_features,
                                   feature_selection='auto',
                                   model_regressor=None):
        """Takes perturbed data, labels and distances, returns explanation.

        Args:
            neighborhood_data: perturbed data, 2d array. first element is
                assumed to be the original data point.
            neighborhood_labels: corresponding perturbed labels. should have
                as many columns as the number of possible labels.
            distances: distances to original data point.
            label: label for which we want an explanation
            num_features: maximum number of features in explanation
            feature_selection: how to select num_features. options are:
                'forward_selection': iteratively add features to the model.
                    This is costly when num_features is high
                'highest_weights': selects the features that have the highest
                    product of absolute weight * original data point when
                    learning with all the features
                'lasso_path': chooses features based on the lasso
                    regularization path
                'none': uses all features, ignores num_features
                'auto': uses forward_selection if num_features <= 6, and
                    'highest_weights' otherwise.
            model_regressor: sklearn regressor to use in explanation.
                Defaults to Ridge regression if None. Must have
                model_regressor.coef_ and 'sample_weight' as a parameter
                to model_regressor.fit()

        Returns:
            (intercept, exp, score, local_pred):
            intercept is a float.
            exp is a sorted list of tuples, where each tuple (x,y) corresponds
                to the feature id (x) and the local weight (y). The list is
                sorted by decreasing absolute value of y.
            score is the R^2 value of the returned explanation
            local_pred is the prediction of the explanation model on the
                original instance
        """
        weights = self.kernel_fn(distances)
        labels_column = neighborhood_labels[:, label]
        used_features = self.feature_selection(neighborhood_data,
                                               labels_column,
                                               weights,
                                               num_features,
                                               feature_selection)
        if model_regressor is None:
            model_regressor = Ridge(alpha=1, fit_intercept=True,
                                    random_state=self.random_state)
        easy_model = model_regressor
        easy_model.fit(neighborhood_data[:, used_features],
                       labels_column, sample_weight=weights)
        prediction_score = easy_model.score(
            neighborhood_data[:, used_features],
            labels_column, sample_weight=weights)
        local_pred = easy_model.predict(
            neighborhood_data[0, used_features].reshape(1, -1))
        if self.verbose:
            print('Intercept', easy_model.intercept_)
            print('Prediction_local', local_pred)
            print('Right:', neighborhood_labels[0, label])
        return (easy_model.intercept_,
                sorted(zip(used_features, easy_model.coef_),
                       key=lambda x: np.abs(x[1]), reverse=True),
                prediction_score, local_pred)
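

# ---------------------------------------------------------------------------
# A minimal usage sketch of driving LimeBase directly on synthetic data. This
# demo function is an illustration added for clarity, not part of the original
# module; the synthetic target and the kernel width of 5.0 are assumptions.
def _demo_lime_base(random_state=0):
    rng = check_random_state(random_state)
    # 100 perturbed samples over 10 binary features; row 0 plays the role of
    # the original instance (all features "on").
    neighborhood_data = rng.randint(0, 2, (100, 10)).astype(float)
    neighborhood_data[0, :] = 1
    # Hypothetical model output driven by features 2 and 5, arranged as a
    # two-column probability matrix (one column per class).
    p = 0.7 * neighborhood_data[:, 2] + 0.3 * neighborhood_data[:, 5]
    neighborhood_labels = np.stack([1 - p, p], axis=1)
    distances = np.sqrt(
        ((neighborhood_data - neighborhood_data[0]) ** 2).sum(axis=1))
    # Exponential kernel, curried the same way LimeImageExplainer does below.
    kernel_fn = partial(lambda d, w: np.sqrt(np.exp(-(d ** 2) / w ** 2)), w=5.0)
    base = LimeBase(kernel_fn, random_state=random_state)
    intercept, exp, score, local_pred = base.explain_instance_with_data(
        neighborhood_data, neighborhood_labels, distances,
        label=1, num_features=3)
    # Features 2 and 5 should surface with the largest absolute weights.
    return exp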


class ImageExplanation(object):
    def __init__(self, image, segments):
        """Init function.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image = image
        self.segments = segments
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = None
    def get_image_and_mask(self, label, positive_only=True, negative_only=False,
                           hide_rest=False, num_features=5, min_weight=0.):
        """Returns the image and mask for an explanation of a label.

        Args:
            label: label to explain
            positive_only: if True, only take superpixels that positively
                contribute to the prediction of the label.
            negative_only: if True, only take superpixels that negatively
                contribute to the prediction of the label. If this and
                positive_only are both False, both negative and positive
                contributions will be taken. Both cannot be True at the
                same time.
            hide_rest: if True, make the non-explanation part of the returned
                image gray
            num_features: number of superpixels to include in explanation
            min_weight: minimum weight of the superpixels to include in
                explanation

        Returns:
            (image, mask), where image is a 3d numpy array and mask is a 2d
            numpy array that can be used with
            skimage.segmentation.mark_boundaries
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_only and negative_only:
            raise ValueError('positive_only and negative_only cannot be True '
                             'at the same time.')
        segments = self.segments
        image = self.image
        exp = self.local_exp[label]
        mask = np.zeros(segments.shape, segments.dtype)
        if hide_rest:
            temp = np.zeros(self.image.shape)
        else:
            temp = self.image.copy()
        if positive_only:
            fs = [x[0] for x in exp
                  if x[1] > 0 and x[1] > min_weight][:num_features]
        if negative_only:
            fs = [x[0] for x in exp
                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
        if positive_only or negative_only:
            for f in fs:
                temp[segments == f] = image[segments == f].copy()
                mask[segments == f] = 1
            return temp, mask
        else:
            for f, w in exp[:num_features]:
                if np.abs(w) < min_weight:
                    continue
                c = 0 if w < 0 else 1
                mask[segments == f] = -1 if w < 0 else 1
                temp[segments == f] = image[segments == f].copy()
                temp[segments == f, c] = np.max(image)
            return temp, mask
    def get_rendered_image(self, label, min_weight=0.005):
        """Renders the explanation for a label as a colored image.

        Args:
            label: label to explain
            min_weight: minimum absolute (max-normalized) weight for a
                superpixel to be colored

        Returns:
            image, a 3d numpy array
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        from matplotlib import cm
        segments = self.segments
        image = self.image
        exp = self.local_exp[label]
        temp = np.zeros_like(image)
        weight_max = abs(exp[0][1])
        exp = [(f, w / weight_max) for f, w in exp]
        exp = sorted(exp, key=lambda x: x[1], reverse=True)  # negatives last
        cmaps = cm.get_cmap('Spectral')
        colors = cmaps(np.linspace(0, 1, len(exp)))
        colors = colors[:, :3]
        for i, (f, w) in enumerate(exp):
            if np.abs(w) < min_weight:
                continue
            temp[segments == f] = colors[i] * 255
        return temp
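

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) of consuming an
# ImageExplanation. `explanation` is assumed to come from
# LimeImageExplainer.explain_instance below, and the /255 rescale assumes a
# uint8 source image.
def _show_top_superpixels(explanation, label, num_features=5):
    from skimage.segmentation import mark_boundaries
    temp, mask = explanation.get_image_and_mask(
        label, positive_only=True, hide_rest=True, num_features=num_features)
    # mark_boundaries expects floats in [0, 1]; it outlines the mask regions.
    return mark_boundaries(temp / 255.0, mask)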


class LimeImageExplainer(object):
    """Explains predictions on image (i.e. matrix) data.

    The image is segmented into superpixels, and the neighborhood is generated
    by randomly turning superpixels on or off: a binary feature is 1 when the
    superpixel is unchanged from the instance being explained, and 0 when it
    has been replaced by a fudged value (see explain_instance)."""

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', random_state=None):
        """Init function.

        Args:
            kernel_width: kernel width for the exponential kernel.
                Defaults to 0.25.
            kernel: similarity kernel that takes euclidean distances and
                kernel width as input and outputs weights in (0,1). If None,
                defaults to an exponential kernel.
            verbose: if true, print local prediction values from linear model
            feature_selection: feature selection method. can be
                'forward_selection', 'lasso_path', 'none' or 'auto'.
                See function 'explain_instance_with_data' in lime_base.py for
                details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        kernel_width = float(kernel_width)
        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
        kernel_fn = partial(kernel, kernel_width=kernel_width)
        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
    def explain_instance(self, image, classifier_fn, labels=(1,),
                         hide_color=None,
                         num_features=100000, num_samples=1000,
                         batch_size=10,
                         distance_metric='cosine',
                         model_regressor=None):
        """Generates explanations for a prediction.

        First, we generate neighborhood data by randomly perturbing features
        from the instance (see data_labels). We then learn locally weighted
        linear models on this neighborhood data to explain each of the classes
        in an interpretable way (see lime_base.py).

        Args:
            image: 3 dimensional RGB image. If this is only two dimensional,
                we will assume it's a grayscale image and call gray2rgb.
            classifier_fn: classifier prediction probability function, which
                takes a numpy array and outputs prediction probabilities. For
                ScikitClassifiers, this is classifier.predict_proba.
            labels: iterable with labels to be explained.
            hide_color: determines how hidden superpixels are filled in. If
                None, each hidden superpixel is replaced by its mean pixel
                value; if 'avg_from_neighbor', by the average of its five
                nearest superpixels; otherwise hidden superpixels are set
                to zero.
            num_features: maximum number of features present in explanation
            num_samples: size of the neighborhood to learn the linear model
            batch_size: number of perturbed images passed to classifier_fn
                per call (see data_labels).
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in explanation. Defaults
                to Ridge regression in LimeBase. Must have model_regressor.coef_
                and 'sample_weight' as a parameter to model_regressor.fit()

        Returns:
            An ImageExplanation object (see lime_image.py) with the
            corresponding explanations.
        """
        if len(image.shape) == 2:
            image = gray2rgb(image)
        segments = quickshift(image, sigma=1)
        self.segments = segments
        fudged_image = image.copy()
        if hide_color is None:
            # if no hide_color, use the mean of each superpixel
            for x in np.unique(segments):
                mx = np.mean(image[segments == x], axis=0)
                fudged_image[segments == x] = mx
        elif hide_color == 'avg_from_neighbor':
            from scipy.spatial.distance import cdist
            n_features = np.unique(segments).shape[0]
            regions = regionprops(segments + 1)
            centroids = np.zeros((n_features, 2))
            for i, x in enumerate(regions):
                centroids[i] = np.array(x.centroid)
            d = cdist(centroids, centroids, 'sqeuclidean')
            for x in np.unique(segments):
                # average the pixels of the five nearest superpixels
                a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
                mx = np.mean(np.concatenate(a), axis=0)
                fudged_image[segments == x] = mx
        else:
            fudged_image[:] = 0
        top = labels
        data, labels = self.data_labels(image, fudged_image, segments,
                                        classifier_fn, num_samples,
                                        batch_size=batch_size)
        distances = sklearn.metrics.pairwise_distances(
            data,
            data[0].reshape(1, -1),
            metric=distance_metric
        ).ravel()
        ret_exp = ImageExplanation(image, segments)
        for label in top:
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
                data, labels, distances, label, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection)
        return ret_exp
    def data_labels(self,
                    image,
                    fudged_image,
                    segments,
                    classifier_fn,
                    num_samples,
                    batch_size=10):
        """Generates images and predictions in the neighborhood of this image.

        Args:
            image: 3d numpy array, the image
            fudged_image: 3d numpy array, image to replace original image when
                superpixel is turned off
            segments: segmentation of the image
            classifier_fn: function that takes a list of images and returns a
                matrix of prediction probabilities
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn will be called on batches of this size.

        Returns:
            A tuple (data, labels), where:
                data: dense binary matrix of shape
                    (num_samples, num_superpixels)
                labels: matrix of prediction probabilities, one row per sample
        """
        n_features = np.unique(segments).shape[0]
        data = self.random_state.randint(0, 2, num_samples * n_features) \
            .reshape((num_samples, n_features))
        labels = []
        data[0, :] = 1
        imgs = []
        for row in data:
            temp = copy.deepcopy(image)
            zeros = np.where(row == 0)[0]
            mask = np.zeros(segments.shape).astype(bool)
            for z in zeros:
                mask[segments == z] = True
            temp[mask] = fudged_image[mask]
            imgs.append(temp)
            if len(imgs) == batch_size:
                preds = classifier_fn(np.array(imgs))
                labels.extend(preds)
                imgs = []
        if len(imgs) > 0:
            preds = classifier_fn(np.array(imgs))
            labels.extend(preds)
        return data, np.array(labels)
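

# ---------------------------------------------------------------------------
# End-to-end usage sketch, guarded so it only runs when this file is executed
# directly. The brightness-based classifier below is a stand-in assumption;
# in practice classifier_fn would wrap a real model's predict_proba.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    # A hypothetical 32x32 RGB image: a bright square on a dark background,
    # which quickshift can segment into a handful of superpixels.
    demo_image = np.zeros((32, 32, 3), dtype=np.uint8)
    demo_image[8:24, 8:24] = 200
    demo_image += (rng.rand(32, 32, 3) * 20).astype(np.uint8)

    def classifier_fn(images):
        # Hypothetical two-class model: probability of class 1 grows with
        # the mean brightness of the image.
        brightness = images.reshape(len(images), -1).mean(axis=1) / 255.0
        return np.stack([1 - brightness, brightness], axis=1)

    explainer = LimeImageExplainer(random_state=0)
    explanation = explainer.explain_instance(
        demo_image, classifier_fn, labels=(1,), num_samples=200)
    # Superpixels covering the bright square should get positive weights.
    print('top superpixels:', explanation.local_exp[1][:5])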