"""
Copyright (c) 2016, Marco Tulio Correia Ribeiro
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
"""
The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime.
"""

import numpy as np
import scipy as sp
import tqdm
import copy
from functools import partial

import paddlex.utils.logging as logging


class LimeBase(object):
    """Class for learning a locally linear sparse model from perturbed data"""

    def __init__(self, kernel_fn, verbose=False, random_state=None):
        """Init function

        Args:
            kernel_fn: function that transforms an array of distances into an
                array of proximity values (floats).
            verbose: if true, print local prediction values from linear model.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        from sklearn.utils import check_random_state

        self.kernel_fn = kernel_fn
        self.verbose = verbose
        self.random_state = check_random_state(random_state)

    @staticmethod
    def generate_lars_path(weighted_data, weighted_labels):
        """Generates the lars path for weighted data.

        Args:
            weighted_data: data that has been weighted by kernel
            weighted_labels: labels, weighted by kernel

        Returns:
            (alphas, coefs), both are arrays corresponding to the
            regularization parameter and coefficients, respectively
        """
        from sklearn.linear_model import lars_path

        x_vector = weighted_data
        alphas, _, coefs = lars_path(
            x_vector, weighted_labels, method='lasso', verbose=False)
        return alphas, coefs

    def forward_selection(self, data, labels, weights, num_features):
        """Iteratively adds features to the model"""
        from sklearn.linear_model import Ridge

        clf = Ridge(
            alpha=0, fit_intercept=True, random_state=self.random_state)
        used_features = []
        for _ in range(min(num_features, data.shape[1])):
            max_ = -100000000
            best = 0
            for feature in range(data.shape[1]):
                if feature in used_features:
                    continue
                clf.fit(data[:, used_features + [feature]],
                        labels,
                        sample_weight=weights)
                score = clf.score(
                    data[:, used_features + [feature]],
                    labels,
                    sample_weight=weights)
                if score > max_:
                    best = feature
                    max_ = score
            used_features.append(best)
        return np.array(used_features)

    def feature_selection(self, data, labels, weights, num_features, method):
        """Selects features for the model. see interpret_instance_with_data to
        understand the parameters."""
        from sklearn.linear_model import Ridge

        if method == 'none':
            return np.array(range(data.shape[1]))
        elif method == 'forward_selection':
            return self.forward_selection(data, labels, weights, num_features)
        elif method == 'highest_weights':
            clf = Ridge(
                alpha=0.01, fit_intercept=True, random_state=self.random_state)
            clf.fit(data, labels, sample_weight=weights)

            coef = clf.coef_
            if sp.sparse.issparse(data):
                coef = sp.sparse.csr_matrix(clf.coef_)
                weighted_data = coef.multiply(data[0])
                # Note: most efficient to slice the data before reversing
                sdata = len(weighted_data.data)
                argsort_data = np.abs(weighted_data.data).argsort()
                # Edge case where data is more sparse than requested number of feature importances
                # In that case, we just pad with zero-valued features
                if sdata < num_features:
                    nnz_indexes = argsort_data[::-1]
                    indices = weighted_data.indices[nnz_indexes]
                    num_to_pad = num_features - sdata
                    indices = np.concatenate((indices, np.zeros(
                        num_to_pad, dtype=indices.dtype)))
                    indices_set = set(indices)
                    pad_counter = 0
                    for i in range(data.shape[1]):
                        if i not in indices_set:
                            indices[pad_counter + sdata] = i
                            pad_counter += 1
                            if pad_counter >= num_to_pad:
                                break
                else:
                    nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
                    indices = weighted_data.indices[nnz_indexes]
                return indices
            else:
                weighted_data = coef * data[0]
                feature_weights = sorted(
                    zip(range(data.shape[1]), weighted_data),
                    key=lambda x: np.abs(x[1]),
                    reverse=True)
                return np.array([x[0] for x in feature_weights[:num_features]])
        elif method == 'lasso_path':
            weighted_data = ((data - np.average(
                data, axis=0, weights=weights)) *
                             np.sqrt(weights[:, np.newaxis]))
            weighted_labels = ((labels - np.average(
                labels, weights=weights)) * np.sqrt(weights))
            nonzero = range(weighted_data.shape[1])
            _, coefs = self.generate_lars_path(weighted_data, weighted_labels)
            for i in range(len(coefs.T) - 1, 0, -1):
                nonzero = coefs.T[i].nonzero()[0]
                if len(nonzero) <= num_features:
                    break
            used_features = nonzero
            return used_features
        elif method == 'auto':
            if num_features <= 6:
                n_method = 'forward_selection'
            else:
                n_method = 'highest_weights'
            return self.feature_selection(data, labels, weights, num_features,
                                          n_method)

    def interpret_instance_with_data(self,
                                     neighborhood_data,
                                     neighborhood_labels,
                                     distances,
                                     label,
                                     num_features,
                                     feature_selection='auto',
                                     model_regressor=None):
        """Takes perturbed data, labels and distances, returns interpretation.

        Args:
            neighborhood_data: perturbed data, 2d array. first element is
                assumed to be the original data point.
            neighborhood_labels: corresponding perturbed labels. should have as
                many columns as the number of possible labels.
            distances: distances to original data point.
            label: label for which we want an interpretation
            num_features: maximum number of features in interpretation
            feature_selection: how to select num_features. options are:
                'forward_selection': iteratively add features to the model.
                    This is costly when num_features is high
                'highest_weights': selects the features that have the highest
                    product of absolute weight * original data point when
                    learning with all the features
                'lasso_path': chooses features based on the lasso
                    regularization path
                'none': uses all features, ignores num_features
                'auto': uses forward_selection if num_features <= 6, and
                    'highest_weights' otherwise.
            model_regressor: sklearn regressor to use in interpretation.
                Defaults to Ridge regression if None. Must have
                model_regressor.coef_ and 'sample_weight' as a parameter
                to model_regressor.fit()

        Returns:
            (intercept, exp, score, local_pred):
                intercept is a float.
                exp is a sorted list of tuples, where each tuple (x, y)
                    corresponds to the feature id (x) and the local weight (y).
                    The list is sorted by decreasing absolute value of y.
                score is the R^2 value of the returned interpretation
                local_pred is the prediction of the interpretation model on the
                    original instance
        """
        from sklearn.linear_model import Ridge

        weights = self.kernel_fn(distances)
        labels_column = neighborhood_labels[:, label]
        used_features = self.feature_selection(neighborhood_data,
                                               labels_column, weights,
                                               num_features, feature_selection)
        if model_regressor is None:
            model_regressor = Ridge(
                alpha=1, fit_intercept=True, random_state=self.random_state)
        easy_model = model_regressor
        easy_model.fit(neighborhood_data[:, used_features],
                       labels_column,
                       sample_weight=weights)
        prediction_score = easy_model.score(
            neighborhood_data[:, used_features],
            labels_column,
            sample_weight=weights)

        local_pred = easy_model.predict(neighborhood_data[0, used_features]
                                        .reshape(1, -1))

        if self.verbose:
            logging.info('Intercept: ' + str(easy_model.intercept_))
            logging.info('Prediction_local: ' + str(local_pred))
            logging.info('Right: ' + str(neighborhood_labels[0, label]))
        return (easy_model.intercept_, sorted(
            zip(used_features, easy_model.coef_),
            key=lambda x: np.abs(x[1]),
            reverse=True), prediction_score, local_pred)
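

# A minimal, illustrative sketch (not part of the upstream LIME code) of how
# LimeBase can be driven directly with already-perturbed data. The synthetic
# arrays, the kernel width of 2.0 and the helper name `_lime_base_example`
# are assumptions for demonstration only; in this package LimeBase is
# normally driven by LimeImageInterpreter below.
def _lime_base_example():
    rng = np.random.RandomState(0)
    # 50 perturbed samples with 8 binary "features"; row 0 is the original.
    neighborhood_data = rng.randint(0, 2, (50, 8))
    neighborhood_data[0, :] = 1
    # Stand-in prediction probabilities for 2 classes, one column per class.
    neighborhood_labels = rng.rand(50, 2)
    # Euclidean distance of each perturbed sample to the original sample.
    distances = np.sqrt(
        ((neighborhood_data - neighborhood_data[0])**2).sum(axis=1))

    # Same exponential kernel shape as LimeImageInterpreter, but with a wider
    # kernel so the synthetic distances still receive non-negligible weight.
    def kernel_fn(d, kernel_width=2.0):
        return np.sqrt(np.exp(-(d**2) / kernel_width**2))

    base = LimeBase(kernel_fn, verbose=False, random_state=0)
    intercept, weights, score, local_pred = base.interpret_instance_with_data(
        neighborhood_data,
        neighborhood_labels,
        distances,
        label=1,
        num_features=4,
        feature_selection='auto')
    # `weights` is a list of (feature_id, local_weight) sorted by |weight|.
    return intercept, weights, score, local_pred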


class ImageInterpretation(object):
    def __init__(self, image, segments):
        """Init function.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image = image
        self.segments = segments
        self.intercept = {}
        self.local_weights = {}
        self.local_pred = None

    def get_image_and_mask(self,
                           label,
                           positive_only=True,
                           negative_only=False,
                           hide_rest=False,
                           num_features=5,
                           min_weight=0.):
        """Returns the image and mask highlighting the interpretation for a label.

        Args:
            label: label to interpret
            positive_only: if True, only take superpixels that positively contribute to
                the prediction of the label.
            negative_only: if True, only take superpixels that negatively contribute to
                the prediction of the label. If False, and so is positive_only, then both
                negative and positive contributions will be taken.
                Both can't be True at the same time.
            hide_rest: if True, make the non-interpretation part of the returned
                image gray
            num_features: number of superpixels to include in interpretation
            min_weight: minimum weight of the superpixels to include in interpretation

        Returns:
            (image, mask), where image is a 3d numpy array and mask is a 2d
            numpy array that can be used with
            skimage.segmentation.mark_boundaries
        """
        if label not in self.local_weights:
            raise KeyError('Label not in interpretation')
        if positive_only and negative_only:
            raise ValueError(
                "positive_only and negative_only cannot be True at the same time."
            )
        segments = self.segments
        image = self.image
        local_weights_label = self.local_weights[label]
        mask = np.zeros(segments.shape, segments.dtype)
        if hide_rest:
            temp = np.zeros(self.image.shape)
        else:
            temp = self.image.copy()
        if positive_only:
            fs = [
                x[0] for x in local_weights_label
                if x[1] > 0 and x[1] > min_weight
            ][:num_features]
        if negative_only:
            fs = [
                x[0] for x in local_weights_label
                if x[1] < 0 and abs(x[1]) > min_weight
            ][:num_features]
        if positive_only or negative_only:
            c = 1 if positive_only else 0
            for f in fs:
                temp[segments == f] = [0, 255, 0]
                # temp[segments == f, c] = np.max(image)
                mask[segments == f] = 1
            return temp, mask
        else:
            for f, w in local_weights_label[:num_features]:
                if np.abs(w) < min_weight:
                    continue
                c = 0 if w < 0 else 1
                mask[segments == f] = -1 if w < 0 else 1
                temp[segments == f] = image[segments == f].copy()
                temp[segments == f, c] = np.max(image)
            return temp, mask

    def get_rendered_image(self, label, min_weight=0.005):
        """Renders the interpretation for a label as a colored image.

        Args:
            label: label to interpret
            min_weight: minimum normalized absolute weight of a superpixel for
                it to be colored

        Returns:
            image, is a 3d numpy array
        """
        if label not in self.local_weights:
            raise KeyError('Label not in interpretation')

        from matplotlib import cm

        segments = self.segments
        image = self.image
        local_weights_label = self.local_weights[label]
        temp = np.zeros_like(image)
        weight_max = abs(local_weights_label[0][1])
        local_weights_label = [(f, w / weight_max)
                               for f, w in local_weights_label]
        local_weights_label = sorted(
            local_weights_label, key=lambda x: x[1],
            reverse=True)  # negatives come last.
        cmaps = cm.get_cmap('Spectral')
        colors = cmaps(np.linspace(0, 1, len(local_weights_label)))
        colors = colors[:, :3]
        for i, (f, w) in enumerate(local_weights_label):
            if np.abs(w) < min_weight:
                continue
            temp[segments == f] = image[segments == f].copy()
            temp[segments == f] = colors[i] * 255
        return temp
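

# A minimal, illustrative sketch (not part of the upstream LIME code) showing
# how an ImageInterpretation object is consumed once `local_weights` has been
# filled in. The toy image, the two-superpixel segmentation and the helper
# name `_image_interpretation_example` are assumptions for demonstration only;
# interpret_instance below produces these objects from a real model.
def _image_interpretation_example():
    # Toy 4x4 RGB image split into two superpixels: left half (0), right half (1).
    image = np.full((4, 4, 3), 128, dtype=np.uint8)
    segments = np.zeros((4, 4), dtype=np.int64)
    segments[:, 2:] = 1

    interpretation = ImageInterpretation(image, segments)
    # Hand-written local weights for label 0, sorted by decreasing |weight|,
    # in the same format produced by LimeBase.interpret_instance_with_data.
    interpretation.local_weights[0] = [(0, 0.8), (1, -0.3)]
    interpretation.intercept[0] = 0.1

    # Highlight the superpixel that pushes the prediction towards label 0.
    temp, mask = interpretation.get_image_and_mask(
        0, positive_only=True, num_features=1, hide_rest=False)
    # Render all superpixels, colored by their normalized weight (needs matplotlib).
    rendered = interpretation.get_rendered_image(0)
    return temp, mask, rendered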


class LimeImageInterpreter(object):
    """Interprets predictions on Image (i.e. matrix) data.

    For numerical features, perturb them by sampling from a Normal(0,1) and
    doing the inverse operation of mean-centering and scaling, according to the
    means and stds in the training data. For categorical features, perturb by
    sampling according to the training distribution, and making a binary
    feature that is 1 when the value is the same as the instance being
    interpreted."""

    def __init__(self,
                 kernel_width=.25,
                 kernel=None,
                 verbose=False,
                 feature_selection='auto',
                 random_state=None):
        """Init function.

        Args:
            kernel_width: kernel width for the exponential kernel.
            kernel: similarity kernel that takes euclidean distances and kernel
                width as input and outputs weights in (0,1). If None, defaults to
                an exponential kernel.
            verbose: if true, print local prediction values from linear model
            feature_selection: feature selection method. can be
                'forward_selection', 'highest_weights', 'lasso_path', 'none' or
                'auto'. See function 'interpret_instance_with_data' in
                lime_base.py for details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        from sklearn.utils import check_random_state

        kernel_width = float(kernel_width)

        if kernel is None:

            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d**2) / kernel_width**2))

        kernel_fn = partial(kernel, kernel_width=kernel_width)

        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = LimeBase(
            kernel_fn, verbose, random_state=self.random_state)

    def interpret_instance(self,
                           image,
                           classifier_fn,
                           labels=(1, ),
                           hide_color=None,
                           num_features=100000,
                           num_samples=1000,
                           batch_size=10,
                           distance_metric='cosine',
                           model_regressor=None):
        """Generates interpretations for a prediction.

        First, we generate neighborhood data by randomly perturbing features
        from the instance (see __data_inverse). We then learn locally weighted
        linear models on this neighborhood data to interpret each of the classes
        in an interpretable way (see lime_base.py).

        Args:
            image: 3 dimension RGB image. If this is only two dimensional,
                we will assume it's a grayscale image and call gray2rgb.
            classifier_fn: classifier prediction probability function, which
                takes a numpy array and outputs prediction probabilities. For
                ScikitClassifiers, this is classifier.predict_proba.
            labels: iterable with labels to be interpreted.
            hide_color: how to fill a superpixel that is turned off. If None,
                each superpixel is replaced by its own mean color; if
                'avg_from_neighbor', by the mean color of its nearest
                neighboring superpixels; otherwise the superpixel is zeroed.
            num_features: maximum number of features present in interpretation
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn will be called on batches of this size.
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in interpretation. Defaults
                to Ridge regression in LimeBase. Must have model_regressor.coef_
                and 'sample_weight' as a parameter to model_regressor.fit()

        Returns:
            An ImageInterpretation object (see lime_image.py) with the
            corresponding interpretations.
        """
        import sklearn
        from skimage.measure import regionprops
        from skimage.segmentation import quickshift
        from skimage.color import gray2rgb

        if len(image.shape) == 2:
            image = gray2rgb(image)

        try:
            segments = quickshift(image, sigma=1)
        except ValueError as e:
            raise e

        self.segments = segments

        fudged_image = image.copy()
        # global_mean = np.mean(image, (0, 1))
        if hide_color is None:
            # if no hide_color, use the mean
            for x in np.unique(segments):
                mx = np.mean(image[segments == x], axis=0)
                fudged_image[segments == x] = mx
        elif hide_color == 'avg_from_neighbor':
            from scipy.spatial.distance import cdist

            n_features = np.unique(segments).shape[0]
            regions = regionprops(segments + 1)
            centroids = np.zeros((n_features, 2))
            for i, x in enumerate(regions):
                centroids[i] = np.array(x.centroid)

            d = cdist(centroids, centroids, 'sqeuclidean')

            for x in np.unique(segments):
                a = [image[segments == i] for i in np.argsort(d[x])[1:6]]
                mx = np.mean(np.concatenate(a), axis=0)
                fudged_image[segments == x] = mx
        else:
            fudged_image[:] = 0

        top = labels

        data, labels = self.data_labels(
            image,
            fudged_image,
            segments,
            classifier_fn,
            num_samples,
            batch_size=batch_size)

        distances = sklearn.metrics.pairwise_distances(
            data, data[0].reshape(1, -1), metric=distance_metric).ravel()

        interpretation_image = ImageInterpretation(image, segments)
        for label in top:
            (interpretation_image.intercept[label],
             interpretation_image.local_weights[label],
             interpretation_image.score, interpretation_image.local_pred
             ) = self.base.interpret_instance_with_data(
                 data,
                 labels,
                 distances,
                 label,
                 num_features,
                 model_regressor=model_regressor,
                 feature_selection=self.feature_selection)
        return interpretation_image

    def data_labels(self,
                    image,
                    fudged_image,
                    segments,
                    classifier_fn,
                    num_samples,
                    batch_size=10):
        """Generates images and predictions in the neighborhood of this image.

        Args:
            image: 3d numpy array, the image
            fudged_image: 3d numpy array, image to replace original image when
                superpixel is turned off
            segments: segmentation of the image
            classifier_fn: function that takes a list of images and returns a
                matrix of prediction probabilities
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn will be called on batches of this size.

        Returns:
            A tuple (data, labels), where:
                data: dense num_samples * num_superpixels binary matrix of
                    perturbations (1 means the superpixel is kept)
                labels: prediction probabilities matrix
        """
        n_features = np.unique(segments).shape[0]
        data = self.random_state.randint(0, 2, num_samples * n_features) \
            .reshape((num_samples, n_features))
        labels = []
        data[0, :] = 1
        imgs = []

        logging.info("Computing LIME.", use_color=True)

        for row in tqdm.tqdm(data):
            temp = copy.deepcopy(image)
            zeros = np.where(row == 0)[0]
            mask = np.zeros(segments.shape).astype(bool)
            for z in zeros:
                mask[segments == z] = True
            temp[mask] = fudged_image[mask]
            imgs.append(temp)
            if len(imgs) == batch_size:
                preds = classifier_fn(np.array(imgs))
                labels.extend(preds)
                imgs = []
        if len(imgs) > 0:
            preds = classifier_fn(np.array(imgs))
            labels.extend(preds)
        return data, np.array(labels)
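

# A minimal, illustrative end-to-end sketch (not part of the upstream LIME
# code). The toy image, the random stand-in classifier and the helper name
# `_lime_image_interpreter_example` are assumptions for demonstration only;
# in PaddleX, `classifier_fn` would wrap a trained model's
# probability-prediction call.
def _lime_image_interpreter_example():
    rng = np.random.RandomState(0)
    # Toy RGB image: a green square on a black background. A real use case
    # would load an actual photo here.
    image = np.zeros((64, 64, 3), dtype=np.uint8)
    image[16:48, 16:48] = (0, 255, 0)

    def classifier_fn(batch):
        # Stand-in classifier: random "probabilities" for 2 classes, one row
        # per image in the batch.
        probs = rng.rand(batch.shape[0], 2)
        return probs / probs.sum(axis=1, keepdims=True)

    interpreter = LimeImageInterpreter(random_state=0)
    interpretation = interpreter.interpret_instance(
        image,
        classifier_fn,
        labels=(1, ),
        num_features=10,
        num_samples=100,
        batch_size=10)
    # `interpretation` is an ImageInterpretation; visualize label 1 with
    # interpretation.get_image_and_mask(1) or interpretation.get_rendered_image(1).
    return interpretation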