|
|
@@ -17,6 +17,7 @@ import cv2
|
|
|
import copy
|
|
|
import os.path as osp
|
|
|
import numpy as np
|
|
|
+from .interpretation_predict import interpretation_predict
|
|
|
from .core.interpretation import Interpretation
|
|
|
from .core.normlime_base import precompute_normlime_weights
|
|
|
|
|
|
@@ -28,6 +29,18 @@ def visualize(img_file,
|
|
|
num_samples=3000,
|
|
|
batch_size=50,
|
|
|
save_dir='./'):
|
|
|
+ """可解释性可视化。
|
|
|
+ Args:
|
|
|
+ img_file (str): 预测图像路径。
|
|
|
+ model (paddlex.cv.models): paddlex中的模型。
|
|
|
+ dataset (paddlex.datasets): 数据集读取器,默认为None。
|
|
|
+ algo (str): 可解释性方式,当前可选'lime'和'normlime'。
|
|
|
+ num_samples (int): 随机采样数量,默认为3000。
|
|
|
+ batch_size (int): 预测数据batch大小,默认为50。
|
|
|
+ save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
|
|
|
+ """
|
|
|
+ assert model.model_type == 'classifier', \
|
|
|
+ 'Now the interpretation visualize only be supported in classifier!'
|
|
|
if model.status != 'Normal':
|
|
|
raise Exception('The interpretation only can deal with the Normal model')
|
|
|
model.arrange_transforms(
|
|
|
@@ -59,7 +72,7 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50):
|
|
|
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
|
|
|
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
|
|
|
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
|
|
|
- out = model.interpretation_predict(image)
|
|
|
+ out = interpretation_predict(model, image)
|
|
|
model.test_transforms.transforms = tmp_transforms
|
|
|
return out[0]
|
|
|
labels_name = None
|
|
|
@@ -78,7 +91,7 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
|
|
|
image = image.astype('float32')
|
|
|
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
|
|
|
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
|
|
|
- out = model.interpretation_predict(image)
|
|
|
+ out = interpretation_predict(model, image)
|
|
|
model.test_transforms.transforms = tmp_transforms
|
|
|
return out[0]
|
|
|
def predict_func(image):
|
|
|
@@ -87,7 +100,7 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
|
|
|
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
|
|
|
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
|
|
|
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
|
|
|
- out = model.interpretation_predict(image)
|
|
|
+ out = interpretation_predict(model, image)
|
|
|
model.test_transforms.transforms = tmp_transforms
|
|
|
return out[0]
|
|
|
labels_name = None
|