# lime_base.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import copy
from functools import partial

import numpy as np
import scipy as sp
import scipy.sparse  # imported explicitly so sp.sparse is always available
import sklearn
import sklearn.metrics  # pairwise_distances is used below
import sklearn.preprocessing
from skimage.color import gray2rgb
from skimage.measure import regionprops
from skimage.segmentation import quickshift
from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state


class LimeBase(object):
    """Class for learning a locally linear sparse model from perturbed data."""

    def __init__(self, kernel_fn, verbose=False, random_state=None):
        """Init function.

        Args:
            kernel_fn: function that transforms an array of distances into an
                array of proximity values (floats).
            verbose: if true, print local prediction values from the linear
                model.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        self.kernel_fn = kernel_fn
        self.verbose = verbose
        self.random_state = check_random_state(random_state)
    @staticmethod
    def generate_lars_path(weighted_data, weighted_labels):
        """Generates the lars path for weighted data.

        Args:
            weighted_data: data that has been weighted by kernel
            weighted_labels: labels, weighted by kernel

        Returns:
            (alphas, coefs), both are arrays corresponding to the
            regularization parameter and coefficients, respectively.
        """
        x_vector = weighted_data
        alphas, _, coefs = lars_path(x_vector,
                                     weighted_labels,
                                     method='lasso',
                                     verbose=False)
        return alphas, coefs
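
    # Illustrative sketch (not part of the original file): for a weighted
    # design matrix X of shape (n_samples, n_features) and weighted labels y
    # of shape (n_samples,), the path can be inspected as
    #
    #     alphas, coefs = LimeBase.generate_lars_path(X, y)
    #     # coefs has shape (n_features, n_alphas); coefs[:, i] holds the
    #     # lasso coefficients at regularization strength alphas[i].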
    def forward_selection(self, data, labels, weights, num_features):
        """Iteratively adds features to the model."""
        clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
        used_features = []
        for _ in range(min(num_features, data.shape[1])):
            max_ = -100000000
            best = 0
            for feature in range(data.shape[1]):
                if feature in used_features:
                    continue
                clf.fit(data[:, used_features + [feature]], labels,
                        sample_weight=weights)
                score = clf.score(data[:, used_features + [feature]],
                                  labels,
                                  sample_weight=weights)
                if score > max_:
                    best = feature
                    max_ = score
            used_features.append(best)
        return np.array(used_features)
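
    # Note on cost (added comment, not in the original file): each round
    # refits the Ridge model once per remaining column, so forward selection
    # performs on the order of num_features * data.shape[1] fits. This is why
    # the 'auto' heuristic below only uses it when num_features <= 6.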
    def feature_selection(self, data, labels, weights, num_features, method):
        """Selects features for the model. See explain_instance_with_data to
        understand the parameters."""
        if method == 'none':
            return np.array(range(data.shape[1]))
        elif method == 'forward_selection':
            return self.forward_selection(data, labels, weights, num_features)
        elif method == 'highest_weights':
            clf = Ridge(alpha=0.01, fit_intercept=True,
                        random_state=self.random_state)
            clf.fit(data, labels, sample_weight=weights)

            coef = clf.coef_
            if sp.sparse.issparse(data):
                coef = sp.sparse.csr_matrix(clf.coef_)
                weighted_data = coef.multiply(data[0])
                # Note: most efficient to slice the data before reversing
                sdata = len(weighted_data.data)
                argsort_data = np.abs(weighted_data.data).argsort()
                # Edge case where data is more sparse than the requested number
                # of feature importances. In that case, we just pad with
                # zero-valued features.
                if sdata < num_features:
                    nnz_indexes = argsort_data[::-1]
                    indices = weighted_data.indices[nnz_indexes]
                    num_to_pad = num_features - sdata
                    indices = np.concatenate(
                        (indices, np.zeros(num_to_pad, dtype=indices.dtype)))
                    indices_set = set(indices)
                    pad_counter = 0
                    for i in range(data.shape[1]):
                        if i not in indices_set:
                            indices[pad_counter + sdata] = i
                            pad_counter += 1
                            if pad_counter >= num_to_pad:
                                break
                else:
                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
                    indices = weighted_data.indices[nnz_indexes]
                return indices
            else:
                weighted_data = coef * data[0]
                feature_weights = sorted(
                    zip(range(data.shape[1]), weighted_data),
                    key=lambda x: np.abs(x[1]),
                    reverse=True)
                return np.array([x[0] for x in feature_weights[:num_features]])
        elif method == 'lasso_path':
            weighted_data = ((data - np.average(data, axis=0, weights=weights))
                             * np.sqrt(weights[:, np.newaxis]))
            weighted_labels = ((labels - np.average(labels, weights=weights))
                               * np.sqrt(weights))
            nonzero = range(weighted_data.shape[1])
            _, coefs = self.generate_lars_path(weighted_data,
                                               weighted_labels)
            for i in range(len(coefs.T) - 1, 0, -1):
                nonzero = coefs.T[i].nonzero()[0]
                if len(nonzero) <= num_features:
                    break
            used_features = nonzero
            return used_features
        elif method == 'auto':
            if num_features <= 6:
                n_method = 'forward_selection'
            else:
                n_method = 'highest_weights'
            return self.feature_selection(data, labels, weights,
                                          num_features, n_method)
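
    # Usage sketch (illustrative, not in the original file): with the toy
    # arrays below, 'auto' dispatches to forward_selection because only two
    # features are requested.
    #
    #     base = LimeBase(kernel_fn=lambda d: np.exp(-d))
    #     data = np.random.randint(0, 2, (100, 10)).astype(float)
    #     labels = data[:, 0] + 0.1 * data[:, 3]
    #     weights = np.ones(100)
    #     base.feature_selection(data, labels, weights, 2, 'auto')
    #     # -> array with the ids of the two most predictive features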
    def explain_instance_with_data(self,
                                   neighborhood_data,
                                   neighborhood_labels,
                                   distances,
                                   label,
                                   num_features,
                                   feature_selection='auto',
                                   model_regressor=None):
        """Takes perturbed data, labels and distances, returns explanation.

        Args:
            neighborhood_data: perturbed data, 2d array. The first element is
                assumed to be the original data point.
            neighborhood_labels: corresponding perturbed labels. Should have as
                many columns as the number of possible labels.
            distances: distances to the original data point.
            label: label for which we want an explanation.
            num_features: maximum number of features in the explanation.
            feature_selection: how to select num_features. Options are:
                'forward_selection': iteratively add features to the model.
                    This is costly when num_features is high.
                'highest_weights': selects the features that have the highest
                    product of absolute weight * original data point when
                    learning with all the features.
                'lasso_path': chooses features based on the lasso
                    regularization path.
                'none': uses all features, ignores num_features.
                'auto': uses forward_selection if num_features <= 6, and
                    'highest_weights' otherwise.
            model_regressor: sklearn regressor to use in the explanation.
                Defaults to Ridge regression if None. Must expose
                model_regressor.coef_ and accept 'sample_weight' as a
                parameter to model_regressor.fit().

        Returns:
            (intercept, exp, score, local_pred):
                intercept is a float.
                exp is a sorted list of tuples, where each tuple (x, y)
                    corresponds to the feature id (x) and the local weight (y).
                    The list is sorted by decreasing absolute value of y.
                score is the R^2 value of the returned explanation.
                local_pred is the prediction of the explanation model on the
                    original instance.
        """
        weights = self.kernel_fn(distances)
        labels_column = neighborhood_labels[:, label]
        used_features = self.feature_selection(neighborhood_data,
                                               labels_column,
                                               weights,
                                               num_features,
                                               feature_selection)
        if model_regressor is None:
            model_regressor = Ridge(alpha=1, fit_intercept=True,
                                    random_state=self.random_state)
        easy_model = model_regressor
        easy_model.fit(neighborhood_data[:, used_features],
                       labels_column, sample_weight=weights)
        prediction_score = easy_model.score(
            neighborhood_data[:, used_features],
            labels_column, sample_weight=weights)

        local_pred = easy_model.predict(
            neighborhood_data[0, used_features].reshape(1, -1))

        if self.verbose:
            print('Intercept', easy_model.intercept_)
            print('Prediction_local', local_pred)
            print('Right:', neighborhood_labels[0, label])
        return (easy_model.intercept_,
                sorted(zip(used_features, easy_model.coef_),
                       key=lambda x: np.abs(x[1]), reverse=True),
                prediction_score, local_pred)
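
    # End-to-end sketch for LimeBase alone (illustrative, not in the original
    # file). Here the "black box" is a noisy linear function of feature 0, so
    # the local model should rank feature 0 highest:
    #
    #     rng = np.random.RandomState(0)
    #     neighborhood_data = rng.randint(0, 2, (200, 5)).astype(float)
    #     f = 0.9 * neighborhood_data[:, 0] + 0.05 * rng.rand(200)
    #     neighborhood_labels = np.stack([1 - f, f], axis=1)
    #     distances = np.linalg.norm(
    #         neighborhood_data - neighborhood_data[0], axis=1)
    #     base = LimeBase(kernel_fn=lambda d: np.exp(-d))
    #     intercept, exp, score, local_pred = base.explain_instance_with_data(
    #         neighborhood_data, neighborhood_labels, distances,
    #         label=1, num_features=2)
    #     # exp[0][0] == 0, i.e. feature 0 carries the largest local weight.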


class ImageExplanation(object):
    def __init__(self, image, segments):
        """Init function.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image = image
        self.segments = segments
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = None
    def get_image_and_mask(self, label, positive_only=True, negative_only=False,
                           hide_rest=False, num_features=5, min_weight=0.):
        """Returns the image and mask for an explained label.

        Args:
            label: label to explain.
            positive_only: if True, only take superpixels that positively
                contribute to the prediction of the label.
            negative_only: if True, only take superpixels that negatively
                contribute to the prediction of the label. If this is False,
                and so is positive_only, then both negative and positive
                contributions will be taken. Both cannot be True at the same
                time.
            hide_rest: if True, make the non-explanation part of the returned
                image gray.
            num_features: number of superpixels to include in the explanation.
            min_weight: minimum weight of the superpixels to include in the
                explanation.

        Returns:
            (image, mask), where image is a 3d numpy array and mask is a 2d
            numpy array that can be used with
            skimage.segmentation.mark_boundaries.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_only and negative_only:
            raise ValueError("positive_only and negative_only cannot be True "
                             "at the same time.")
        segments = self.segments
        image = self.image
        exp = self.local_exp[label]
        mask = np.zeros(segments.shape, segments.dtype)
        if hide_rest:
            temp = np.zeros(self.image.shape)
        else:
            temp = self.image.copy()
        if positive_only:
            fs = [x[0] for x in exp
                  if x[1] > 0 and x[1] > min_weight][:num_features]
        if negative_only:
            fs = [x[0] for x in exp
                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
        if positive_only or negative_only:
            for f in fs:
                temp[segments == f] = image[segments == f].copy()
                mask[segments == f] = 1
            return temp, mask
        else:
            for f, w in exp[:num_features]:
                if np.abs(w) < min_weight:
                    continue
                c = 0 if w < 0 else 1
                mask[segments == f] = -1 if w < 0 else 1
                temp[segments == f] = image[segments == f].copy()
                temp[segments == f, c] = np.max(image)
            return temp, mask
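
    # Usage sketch (illustrative, not in the original file), assuming
    # `explanation` is the object returned by
    # LimeImageExplainer.explain_instance and `label` was among the explained
    # labels:
    #
    #     from skimage.segmentation import mark_boundaries
    #     temp, mask = explanation.get_image_and_mask(
    #         label, positive_only=True, num_features=5, hide_rest=False)
    #     overlay = mark_boundaries(temp / 255.0, mask)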
    def get_rendered_image(self, label, min_weight=0.005):
        """Renders a colored overlay of the explanation for a given label.

        Args:
            label: label to explain.
            min_weight: minimum normalized absolute weight a superpixel must
                have to be colored.

        Returns:
            image, a 3d numpy array.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        from matplotlib import cm
        segments = self.segments
        image = self.image
        exp = self.local_exp[label]
        temp = np.zeros_like(image)
        weight_max = abs(exp[0][1])
        exp = [(f, w / weight_max) for f, w in exp]
        exp = sorted(exp, key=lambda x: x[1], reverse=True)  # negatives come last.
        cmaps = cm.get_cmap('Spectral')
        # sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
        colors = cmaps(np.linspace(0, 1, len(exp)))
        colors = colors[:, :3]
        for i, (f, w) in enumerate(exp):
            if np.abs(w) < min_weight:
                continue
            temp[segments == f] = image[segments == f].copy()
            temp[segments == f] = colors[i] * 255
        return temp


class LimeImageExplainer(object):
    """Explains predictions on image (i.e. matrix) data.

    Neighborhood samples are generated by turning superpixels of the image on
    and off (see data_labels): each superpixel becomes a binary interpretable
    feature that is 1 when the superpixel is kept as-is and 0 when it is
    replaced by a fudged value."""

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', random_state=None):
        """Init function.

        Args:
            kernel_width: kernel width for the exponential kernel.
            kernel: similarity kernel that takes euclidean distances and
                kernel width as input and outputs weights in (0, 1). If None,
                defaults to an exponential kernel.
            verbose: if true, print local prediction values from the linear
                model.
            feature_selection: feature selection method. Can be
                'forward_selection', 'highest_weights', 'lasso_path', 'none'
                or 'auto'. See 'explain_instance_with_data' in lime_base.py
                for details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        kernel_width = float(kernel_width)

        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        kernel_fn = partial(kernel, kernel_width=kernel_width)

        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state)
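
    # A quick numeric check of the default kernel (added comment, not in the
    # original file): with kernel_width = 0.25, a cosine distance of 0 maps
    # to weight sqrt(exp(0)) = 1.0 and a distance of 0.25 to
    # sqrt(exp(-1)) ~= 0.61, so samples far from the original image
    # contribute little to the weighted fit.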
    def explain_instance(self, image, classifier_fn, labels=(1,),
                         hide_color=None,
                         num_features=100000, num_samples=1000,
                         batch_size=10,
                         distance_metric='cosine',
                         model_regressor=None):
        """Generates explanations for a prediction.

        First, we generate neighborhood data by randomly perturbing features
        (superpixels) of the instance (see data_labels). We then learn
        locally weighted linear models on this neighborhood data to explain
        each of the classes in an interpretable way (see lime_base.py).

        Args:
            image: 3-dimensional RGB image. If this is only two-dimensional,
                we will assume it is a grayscale image and call gray2rgb.
            classifier_fn: classifier prediction probability function, which
                takes a numpy array and outputs prediction probabilities. For
                scikit-learn classifiers, this is classifier.predict_proba.
            labels: iterable with labels to be explained.
            hide_color: TODO
            num_features: maximum number of features present in the
                explanation.
            num_samples: size of the neighborhood to learn the linear model.
            batch_size: TODO
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in the explanation.
                Defaults to Ridge regression in LimeBase. Must expose
                model_regressor.coef_ and accept 'sample_weight' as a
                parameter to model_regressor.fit().

        Returns:
            An ImageExplanation object (see lime_image.py) with the
            corresponding explanations.
        """
        if len(image.shape) == 2:
            image = gray2rgb(image)

        segments = quickshift(image, sigma=1)
        self.segments = segments

        fudged_image = image.copy()
        if hide_color is None:
            # if no hide_color is given, replace each superpixel by its mean
            for x in np.unique(segments):
                mx = np.mean(image[segments == x], axis=0)
                fudged_image[segments == x] = mx
        elif hide_color == 'avg_from_neighbor':
            from scipy.spatial.distance import cdist

            n_features = np.unique(segments).shape[0]
            regions = regionprops(segments + 1)
            centroids = np.zeros((n_features, 2))
            for i, x in enumerate(regions):
                centroids[i] = np.array(x.centroid)

            d = cdist(centroids, centroids, 'sqeuclidean')

            for x in np.unique(segments):
                # replace the superpixel with the mean of its five nearest
                # neighboring superpixels (by centroid distance)
                a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
                mx = np.mean(np.concatenate(a), axis=0)
                fudged_image[segments == x] = mx
        else:
            fudged_image[:] = 0

        top = labels

        data, labels = self.data_labels(image, fudged_image, segments,
                                        classifier_fn, num_samples,
                                        batch_size=batch_size)

        distances = sklearn.metrics.pairwise_distances(
            data,
            data[0].reshape(1, -1),
            metric=distance_metric
        ).ravel()

        ret_exp = ImageExplanation(image, segments)
        for label in top:
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
                 data, labels, distances, label, num_features,
                 model_regressor=model_regressor,
                 feature_selection=self.feature_selection)
        return ret_exp
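
    # Usage sketch (illustrative, not in the original file; `model` and its
    # reshaping are hypothetical stand-ins for a real classifier):
    #
    #     explainer = LimeImageExplainer()
    #     def classifier_fn(images):
    #         # must return an (n_images, n_classes) probability matrix
    #         return model.predict_proba(images.reshape(len(images), -1))
    #     explanation = explainer.explain_instance(
    #         image, classifier_fn, labels=(0, 1), num_samples=1000)
    #     temp, mask = explanation.get_image_and_mask(1, num_features=5)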
    def data_labels(self,
                    image,
                    fudged_image,
                    segments,
                    classifier_fn,
                    num_samples,
                    batch_size=10):
        """Generates images and predictions in the neighborhood of this image.

        Args:
            image: 3d numpy array, the image.
            fudged_image: 3d numpy array, image to replace the original image
                when a superpixel is turned off.
            segments: segmentation of the image.
            classifier_fn: function that takes a list of images and returns a
                matrix of prediction probabilities.
            num_samples: size of the neighborhood to learn the linear model.
            batch_size: classifier_fn will be called on batches of this size.

        Returns:
            A tuple (data, labels), where:
                data: dense num_samples * num_superpixels binary matrix.
                labels: prediction probabilities matrix.
        """
        n_features = np.unique(segments).shape[0]
        data = self.random_state.randint(0, 2, num_samples * n_features) \
            .reshape((num_samples, n_features))
        labels = []
        data[0, :] = 1  # the first sample is the unperturbed instance
        imgs = []
        for row in data:
            temp = copy.deepcopy(image)
            zeros = np.where(row == 0)[0]
            mask = np.zeros(segments.shape).astype(bool)
            for z in zeros:
                mask[segments == z] = True
            temp[mask] = fudged_image[mask]
            imgs.append(temp)
            if len(imgs) == batch_size:
                preds = classifier_fn(np.array(imgs))
                labels.extend(preds)
                imgs = []
        if len(imgs) > 0:
            preds = classifier_fn(np.array(imgs))
            labels.extend(preds)
        return data, np.array(labels)
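

if __name__ == '__main__':
    # Minimal smoke test (added for illustration, not part of the original
    # file). The "classifier" is a random stub, so the resulting explanation
    # is meaningless; this only demonstrates the API end to end.
    rng = np.random.RandomState(0)
    demo_image = (rng.rand(32, 32, 3) * 255).astype(np.uint8)

    def random_classifier_fn(images):
        # Returns an (n_images, 2) row-stochastic "probability" matrix.
        p = rng.rand(len(images), 1)
        return np.hstack([p, 1 - p])

    explainer = LimeImageExplainer(random_state=0)
    explanation = explainer.explain_instance(
        demo_image, random_classifier_fn, labels=(0,),
        num_samples=50, batch_size=10)
    temp, mask = explanation.get_image_and_mask(
        0, positive_only=True, num_features=3)
    print('superpixels:', np.unique(explanation.segments).size,
          'mask shape:', mask.shape)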