|
|
@@ -17,6 +17,7 @@ import cv2
|
|
|
import copy
|
|
|
import os.path as osp
|
|
|
import numpy as np
|
|
|
+import paddlex as pdx
|
|
|
from .interpretation_predict import interpretation_predict
|
|
|
from .core.interpretation import Interpretation
|
|
|
from .core.normlime_base import precompute_normlime_weights
|
|
|
@@ -35,7 +36,7 @@ def visualize(img_file,
|
|
|
model (paddlex.cv.models): paddlex中的模型。
|
|
|
dataset (paddlex.datasets): 数据集读取器,默认为None。
|
|
|
algo (str): 可解释性方式,当前可选'lime'和'normlime'。
|
|
|
- num_samples (int): 随机采样数量,默认为3000。
|
|
|
+ num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
|
|
|
batch_size (int): 预测数据batch大小,默认为50。
|
|
|
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
|
|
|
"""
|
|
|
@@ -111,8 +112,8 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
|
|
|
pre_models_path = osp.join(root_path, "pre_models")
|
|
|
if not osp.exists(pre_models_path):
|
|
|
os.makedirs(pre_models_path)
|
|
|
- # TODO
|
|
|
- # paddlex.utils.download_and_decompress(url, path=pre_models_path)
|
|
|
+ url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
|
|
|
+ pdx.utils.download_and_decompress(url, path=pre_models_path)
|
|
|
npy_dir = precompute_for_normlime(precompute_predict_func,
|
|
|
dataset,
|
|
|
num_samples=num_samples,
|