|
|
@@ -22,20 +22,65 @@ from .interpretation_predict import interpretation_predict
|
|
|
from .core.interpretation import Interpretation
|
|
|
from .core.normlime_base import precompute_normlime_weights
|
|
|
from .core._session_preparation import gen_user_home
|
|
|
-
|
|
|
-def visualize(img_file,
|
|
|
+
|
|
|
+def lime(img_file,
|
|
|
+ model,
|
|
|
+ num_samples=3000,
|
|
|
+ batch_size=50,
|
|
|
+ save_dir='./'):
|
|
|
+ """使用LIME算法将模型预测结果的可解释性可视化。
|
|
|
+
|
|
|
+ LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,
|
|
|
+ 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入
|
|
|
+ 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,
|
|
|
+ 得到每个输入维度的权重,以此来解释模型。
|
|
|
+
|
|
|
+ 注意:LIME可解释性结果可视化目前只支持分类模型。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img_file (str): 预测图像路径。
|
|
|
+ model (paddlex.cv.models): paddlex中的模型。
|
|
|
+ num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
|
|
|
+ batch_size (int): 预测数据batch大小,默认为50。
|
|
|
+ save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
|
|
|
+ """
|
|
|
+ assert model.model_type == 'classifier', \
|
|
|
+ 'Now the interpretation visualize only be supported in classifier!'
|
|
|
+ if model.status != 'Normal':
|
|
|
+ raise Exception('The interpretation only can deal with the Normal model')
|
|
|
+ if not osp.exists(save_dir):
|
|
|
+ os.makedirs(save_dir)
|
|
|
+ model.arrange_transforms(
|
|
|
+ transforms=model.test_transforms, mode='test')
|
|
|
+ tmp_transforms = copy.deepcopy(model.test_transforms)
|
|
|
+ tmp_transforms.transforms = tmp_transforms.transforms[:-2]
|
|
|
+ img = tmp_transforms(img_file)[0]
|
|
|
+ img = np.around(img).astype('uint8')
|
|
|
+ img = np.expand_dims(img, axis=0)
|
|
|
+ interpreter = None
|
|
|
+ interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size)
|
|
|
+ img_name = osp.splitext(osp.split(img_file)[-1])[0]
|
|
|
+ interpreter.interpret(img, save_dir=save_dir)
|
|
|
+
|
|
|
+
|
|
|
+def normlime(img_file,
|
|
|
model,
|
|
|
dataset=None,
|
|
|
- algo='lime',
|
|
|
num_samples=3000,
|
|
|
batch_size=50,
|
|
|
save_dir='./'):
|
|
|
- """可解释性可视化。
|
|
|
+ """使用NormLIME算法将模型预测结果的可解释性可视化。
|
|
|
+
|
|
|
+ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测
|
|
|
+ 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。
|
|
|
+
|
|
|
+ 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。
|
|
|
+ 注意2:NormLIME可解释性结果可视化目前只支持分类模型。
|
|
|
+
|
|
|
Args:
|
|
|
img_file (str): 预测图像路径。
|
|
|
model (paddlex.cv.models): paddlex中的模型。
|
|
|
dataset (paddlex.datasets): 数据集读取器,默认为None。
|
|
|
- algo (str): 可解释性方式,当前可选'lime'和'normlime'。
|
|
|
num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
|
|
|
batch_size (int): 预测数据batch大小,默认为50。
|
|
|
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
|
|
|
@@ -54,21 +99,16 @@ def visualize(img_file,
|
|
|
img = np.around(img).astype('uint8')
|
|
|
img = np.expand_dims(img, axis=0)
|
|
|
interpreter = None
|
|
|
- if algo == 'lime':
|
|
|
- interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
|
|
|
- elif algo == 'normlime':
|
|
|
- if dataset is None:
|
|
|
- raise Exception('The dataset is None. Cannot implement this kind of interpretation')
|
|
|
- interpreter = get_normlime_interpreter(img, model, dataset,
|
|
|
- num_samples=num_samples, batch_size=batch_size,
|
|
|
+ if dataset is None:
|
|
|
+ raise Exception('The dataset is None. Cannot implement this kind of interpretation')
|
|
|
+ interpreter = get_normlime_interpreter(img, model, dataset,
|
|
|
+ num_samples=num_samples, batch_size=batch_size,
|
|
|
save_dir=save_dir)
|
|
|
- else:
|
|
|
- raise Exception('The {} interpretation method is not supported yet!'.format(algo))
|
|
|
img_name = osp.splitext(osp.split(img_file)[-1])[0]
|
|
|
interpreter.interpret(img, save_dir=save_dir)
|
|
|
|
|
|
|
|
|
-def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
|
|
|
+def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
|
|
|
def predict_func(image):
|
|
|
image = image.astype('float32')
|
|
|
for i in range(image.shape[0]):
|
|
|
@@ -79,8 +119,8 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
|
|
|
model.test_transforms.transforms = tmp_transforms
|
|
|
return out[0]
|
|
|
labels_name = None
|
|
|
- if dataset is not None:
|
|
|
- labels_name = dataset.labels
|
|
|
+ if hasattr(model, 'labels'):
|
|
|
+ labels_name = model.labels
|
|
|
interpreter = Interpretation('lime',
|
|
|
predict_func,
|
|
|
labels_name,
|