sunyanfang01 5 years ago
Parent
Commit
fd222281d2
65 changed files with 432 additions and 135 deletions
  1. paddlex/cv/models/classifier.py (+2 -4)
  2. paddlex/cv/models/explanation/__pycache__/visualize.cpython-37.pyc (BIN)
  3. paddlex/cv/models/explanation/as_data_reader/__pycache__/data_path_utils.cpython-37.pyc (BIN)
  4. paddlex/cv/models/explanation/as_data_reader/__pycache__/readers.cpython-37.pyc (BIN)
  5. paddlex/cv/models/explanation/as_data_reader/data_path_utils.py (+1 -1)
  6. paddlex/cv/models/explanation/core/__pycache__/_session_preparation.cpython-37.pyc (BIN)
  7. paddlex/cv/models/explanation/core/__pycache__/explanation.cpython-37.pyc (BIN)
  8. paddlex/cv/models/explanation/core/__pycache__/explanation_algorithms.cpython-37.pyc (BIN)
  9. paddlex/cv/models/explanation/core/__pycache__/lime_base.cpython-37.pyc (BIN)
  10. paddlex/cv/models/explanation/core/__pycache__/normlime_base.cpython-37.pyc (BIN)
  11. paddlex/cv/models/explanation/core/_session_preparation.py (+87 -1)
  12. paddlex/cv/models/explanation/core/explanation.py (+1 -1)
  13. paddlex/cv/models/explanation/core/explanation_algorithms.py (+51 -104)
  14. paddlex/cv/models/explanation/core/normlime_base.py (+221 -0)
  15. paddlex/cv/models/explanation/pre_models/bnv1_1_mean (BIN)
  16. paddlex/cv/models/explanation/pre_models/bnv1_1_offset (BIN)
  17. paddlex/cv/models/explanation/pre_models/bnv1_1_scale (BIN)
  18. paddlex/cv/models/explanation/pre_models/bnv1_1_variance (BIN)
  19. paddlex/cv/models/explanation/pre_models/bnv1_2_mean (BIN)
  20. paddlex/cv/models/explanation/pre_models/bnv1_2_offset (BIN)
  21. paddlex/cv/models/explanation/pre_models/bnv1_2_scale (BIN)
  22. paddlex/cv/models/explanation/pre_models/bnv1_2_variance (BIN)
  23. paddlex/cv/models/explanation/pre_models/bnv1_3_mean (BIN)
  24. paddlex/cv/models/explanation/pre_models/bnv1_3_offset (BIN)
  25. paddlex/cv/models/explanation/pre_models/bnv1_3_scale (BIN)
  26. paddlex/cv/models/explanation/pre_models/bnv1_3_variance (BIN)
  27. paddlex/cv/models/explanation/pre_models/conv1_1_weights (BIN)
  28. paddlex/cv/models/explanation/pre_models/conv1_2_weights (BIN)
  29. paddlex/cv/models/explanation/pre_models/conv1_3_weights (BIN)
  30. paddlex/cv/models/explanation/pre_models/kmeans_model.pkl (BIN)
  31. paddlex/cv/models/explanation/pre_models/normlime_weights_imagenet_resnet50vc.npy (BIN)
  32. paddlex/cv/models/explanation/visualize.py (+57 -0)
  33. paddlex/cv/nets/__pycache__/__init__.cpython-37.pyc (BIN)
  34. paddlex/cv/nets/__pycache__/backbone_utils.cpython-37.pyc (BIN)
  35. paddlex/cv/nets/__pycache__/darknet.cpython-37.pyc (BIN)
  36. paddlex/cv/nets/__pycache__/densenet.cpython-37.pyc (BIN)
  37. paddlex/cv/nets/__pycache__/mobilenet_v1.cpython-37.pyc (BIN)
  38. paddlex/cv/nets/__pycache__/mobilenet_v2.cpython-37.pyc (BIN)
  39. paddlex/cv/nets/__pycache__/mobilenet_v3.cpython-37.pyc (BIN)
  40. paddlex/cv/nets/__pycache__/resnet.cpython-37.pyc (BIN)
  41. paddlex/cv/nets/__pycache__/shufflenet_v2.cpython-37.pyc (BIN)
  42. paddlex/cv/nets/__pycache__/xception.cpython-37.pyc (BIN)
  43. paddlex/cv/nets/darknet.py (+2 -5)
  44. paddlex/cv/nets/densenet.py (+2 -3)
  45. paddlex/cv/nets/detection/__pycache__/__init__.cpython-37.pyc (BIN)
  46. paddlex/cv/nets/detection/__pycache__/bbox_head.cpython-37.pyc (BIN)
  47. paddlex/cv/nets/detection/__pycache__/faster_rcnn.cpython-37.pyc (BIN)
  48. paddlex/cv/nets/detection/__pycache__/fpn.cpython-37.pyc (BIN)
  49. paddlex/cv/nets/detection/__pycache__/mask_head.cpython-37.pyc (BIN)
  50. paddlex/cv/nets/detection/__pycache__/mask_rcnn.cpython-37.pyc (BIN)
  51. paddlex/cv/nets/detection/__pycache__/roi_extractor.cpython-37.pyc (BIN)
  52. paddlex/cv/nets/detection/__pycache__/rpn_head.cpython-37.pyc (BIN)
  53. paddlex/cv/nets/detection/__pycache__/yolo_v3.cpython-37.pyc (BIN)
  54. paddlex/cv/nets/mobilenet_v1.py (+2 -4)
  55. paddlex/cv/nets/mobilenet_v2.py (+1 -3)
  56. paddlex/cv/nets/mobilenet_v3.py (+1 -2)
  57. paddlex/cv/nets/resnet.py (+1 -1)
  58. paddlex/cv/nets/segmentation/__pycache__/__init__.cpython-37.pyc (BIN)
  59. paddlex/cv/nets/segmentation/__pycache__/deeplabv3p.cpython-37.pyc (BIN)
  60. paddlex/cv/nets/segmentation/__pycache__/unet.cpython-37.pyc (BIN)
  61. paddlex/cv/nets/segmentation/model_utils/__pycache__/__init__.cpython-37.pyc (BIN)
  62. paddlex/cv/nets/segmentation/model_utils/__pycache__/libs.cpython-37.pyc (BIN)
  63. paddlex/cv/nets/segmentation/model_utils/__pycache__/loss.cpython-37.pyc (BIN)
  64. paddlex/cv/nets/shufflenet_v2.py (+1 -3)
  65. paddlex/cv/nets/xception.py (+2 -3)

+ 2 - 4
paddlex/cv/models/classifier.py

@@ -61,11 +61,9 @@ class BaseClassifier(BaseAPI):
             label = fluid.data(dtype='int64', shape=[None, 1], name='label')
         model = getattr(paddlex.cv.nets, str.lower(self.model_name))
         net_out = model(image, num_classes=self.num_classes)
-        softmax_out = fluid.layers.softmax(net_out['logits'], use_cudnn=False)
+        softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
         inputs = OrderedDict([('image', image)])
-        outputs = net_out
-        outputs.update({'predict': softmax_out})
-        outputs.move_to_end('predict', last=False)
+        outputs = OrderedDict([('predict', softmax_out), ('logits', net_out)])
         if mode != 'test':
             cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
             avg_cost = fluid.layers.mean(cost)
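
With this commit the backbones return the raw logits Variable directly (see the nets/ changes below), so the classifier assembles the output dict itself instead of mutating the backbone's OrderedDict. A minimal sketch of the new contract; assemble_classifier_outputs is an illustrative name, not a function in this codebase:

    from collections import OrderedDict
    import paddle.fluid as fluid

    def assemble_classifier_outputs(net_out):
        # net_out: the raw logits Variable returned by the backbone
        # (before this commit, backbones returned OrderedDict([('logits', out)])).
        softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
        # 'predict' comes first so callers that fetch the first output
        # still receive the class probabilities.
        return OrderedDict([('predict', softmax_out), ('logits', net_out)])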

BIN
paddlex/cv/models/explanation/__pycache__/visualize.cpython-37.pyc


BIN
paddlex/cv/models/explanation/as_data_reader/__pycache__/data_path_utils.cpython-37.pyc


BIN
paddlex/cv/models/explanation/as_data_reader/__pycache__/readers.cpython-37.pyc


+ 1 - 1
paddlex/cv/models/explanation/as_data_reader/data_path_utils.py

@@ -19,4 +19,4 @@ def _find_classes(dir):
     classes = [d.name for d in os.scandir(dir) if d.is_dir()]
     classes.sort()
     class_to_idx = {classes[i]: i for i in range(len(classes))}
-    return classes, class_to_idx
+    return classes, class_to_idx

BIN
paddlex/cv/models/explanation/core/__pycache__/_session_preparation.cpython-37.pyc


BIN
paddlex/cv/models/explanation/core/__pycache__/explanation.cpython-37.pyc


BIN
paddlex/cv/models/explanation/core/__pycache__/explanation_algorithms.cpython-37.pyc


BIN
paddlex/cv/models/explanation/core/__pycache__/lime_base.cpython-37.pyc


BIN
paddlex/cv/models/explanation/core/__pycache__/normlime_base.cpython-37.pyc


+ 87 - 1
paddlex/cv/models/explanation/core/_session_preparation.py

@@ -15,6 +15,13 @@
 import os
 import paddle.fluid as fluid
 import numpy as np
+from paddle.fluid.param_attr import ParamAttr
+from ..as_data_reader.readers import preprocess_image
+
+root_path = os.environ['HOME']
+root_path = os.path.join(root_path, '.paddlex')
+h_pre_models = os.path.join(root_path, "pre_models")
+h_pre_models_kmeans = os.path.join(h_pre_models, "kmeans_model.pkl")
 
 
 def paddle_get_fc_weights(var_name="fc_0.w_0"):
@@ -24,4 +31,83 @@ def paddle_get_fc_weights(var_name="fc_0.w_0"):
 
 def paddle_resize(extracted_features, outsize):
     resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
-    return resized_features
+    return resized_features
+
+
+def compute_features_for_kmeans(data_content):
+    def conv_bn_layer(input,
+                      num_filters,
+                      filter_size,
+                      stride=1,
+                      groups=1,
+                      act=None,
+                      name=None,
+                      is_test=True,
+                      global_name=''):
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            act=None,
+            param_attr=ParamAttr(name=global_name + name + "_weights"),
+            bias_attr=False,
+            name=global_name + name + '.conv2d.output.1')
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        return fluid.layers.batch_norm(
+            input=conv,
+            act=act,
+            name=global_name + bn_name + '.output.1',
+            param_attr=ParamAttr(global_name + bn_name + '_scale'),
+            bias_attr=ParamAttr(global_name + bn_name + '_offset'),
+            moving_mean_name=global_name + bn_name + '_mean',
+            moving_variance_name=global_name + bn_name + '_variance',
+            use_global_stats=is_test
+        )
+
+    startup_prog = fluid.default_startup_program().clone(for_test=True)
+    prog = fluid.Program()
+    with fluid.program_guard(prog, startup_prog):
+        with fluid.unique_name.guard():
+            image_op = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
+
+            conv = conv_bn_layer(
+                input=image_op,
+                num_filters=32,
+                filter_size=3,
+                stride=2,
+                act='relu',
+                name='conv1_1')
+            conv = conv_bn_layer(
+                input=conv,
+                num_filters=32,
+                filter_size=3,
+                stride=1,
+                act='relu',
+                name='conv1_2')
+            conv = conv_bn_layer(
+                input=conv,
+                num_filters=64,
+                filter_size=3,
+                stride=1,
+                act='relu',
+                name='conv1_3')
+            extracted_features = conv
+            resized_features = fluid.layers.resize_bilinear(extracted_features, image_op.shape[2:])
+
+    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+    place = fluid.CUDAPlace(gpu_id)
+    # place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(startup_prog)
+    fluid.io.load_persistables(exe, h_pre_models, prog)
+
+    images = preprocess_image(data_content)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
+    result = exe.run(prog, fetch_list=[resized_features], feed={'image': images})
+
+    return result[0][0]
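
The new compute_features_for_kmeans builds a small three-layer conv-bn stem, loads its persistables from ~/.paddlex/pre_models (the binaries added under pre_models/ below), and bilinearly resizes the 64-channel feature map back to the 224x224 input size. A hedged usage sketch, assuming the pre_models directory has already been placed under ~/.paddlex and a CUDA device is available (the place is hard-coded to CUDAPlace above); 'test.jpg' is a placeholder path:

    from paddlex.cv.models.explanation.as_data_reader.readers import read_image
    from paddlex.cv.models.explanation.core._session_preparation import (
        compute_features_for_kmeans)

    image = read_image('test.jpg')                 # batch of one image, as used elsewhere here
    features = compute_features_for_kmeans(image)  # ndarray of shape (64, 224, 224)
    features_hwc = features.transpose((1, 2, 0))   # (224, 224, 64), the layout that
                                                   # get_feature_for_kmeans() consumes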

+ 1 - 1
paddlex/cv/models/explanation/core/explanation.py

@@ -13,6 +13,7 @@
 #limitations under the License.
 
 from .explanation_algorithms import CAM, LIME, NormLIME
+from .normlime_base import precompute_normlime_weights
 
 
 class Explanation(object):
@@ -48,4 +49,3 @@ class Explanation(object):
 
         """
         return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
-

+ 51 - 104
paddlex/cv/models/explanation/core/explanation_algorithms.py

@@ -18,7 +18,8 @@ import time
 
 from . import lime_base
 from ..as_data_reader.readers import read_image
-from ._session_preparation import paddle_get_fc_weights
+from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, h_pre_models_kmeans
+from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
 
 import cv2
 
@@ -37,8 +38,8 @@ class CAM(object):
         """
         self.predict_fn = predict_fn
 
-    def preparation_cam(self, data_path):
-        image_show = read_image(data_path)
+    def preparation_cam(self, data_):
+        image_show = read_image(data_)
         result = self.predict_fn(image_show)
 
         logit = result[0][0]
@@ -61,7 +62,7 @@ class CAM(object):
         fc_weights = paddle_get_fc_weights()
         feature_maps = result[1]
 
-        print('predicted result: ', pred_label[0], probability[pred_label[0]])
+        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
         return feature_maps, fc_weights
 
     def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
@@ -115,8 +116,8 @@ class LIME(object):
         self.image = None
         self.lime_explainer = None
 
-    def preparation_lime(self, data_path):
-        image_show = read_image(data_path)
+    def preparation_lime(self, data_):
+        image_show = read_image(data_)
         result = self.predict_fn(image_show)
 
         result = result[0]  # only one image here.
@@ -137,7 +138,7 @@ class LIME(object):
         self.image = image_show[0]
         self.labels = pred_label
 
-        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}')
+        print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
 
         end = time.time()
         algo = lime_base.LimeImageExplainer()
@@ -157,7 +158,7 @@ class LIME(object):
 
             psize = 5
             nrows = 2
-            weights_choices = [0.6, 0.75, 0.85]
+            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
             ncols = len(weights_choices)
 
             plt.close()
@@ -193,15 +194,25 @@ class LIME(object):
 class NormLIME(object):
     def __init__(self, predict_fn, num_samples=3000, batch_size=50,
                  kmeans_model_for_normlime=None, normlime_weights=None):
-        assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model."
-        if normlime_weights is None:
-            raise NotImplementedError("Computing NormLIME weights is not implemented yet.")
+        if kmeans_model_for_normlime is None:
+            try:
+                self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+            except:
+                raise ValueError("NormLIME needs the KMeans model; a default one is provided at "
+                                 "pre_models/kmeans_model.pkl.")
+        else:
+            print("Warning: It is *strongly* suggested to use the default KMeans model in pre_models/kmeans_model.pkl. "
+                  "Using another one will change the final result.")
+            self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
 
         self.num_samples = num_samples
         self.batch_size = batch_size
 
-        self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
-        self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
+        try:
+            self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
+        except:
+            self.normlime_weights = None
+            print("Warning: cannot find the correct precomputed NormLIME result.")
 
         self.predict_fn = predict_fn
 
@@ -215,9 +226,8 @@ class NormLIME(object):
         # global weights
         g_weights = {y: [] for y in pred_labels}
         for y in pred_labels:
-            cluster_weights_y = self.normlime_weights[y]
+            cluster_weights_y = self.normlime_weights.get(y, {})
             g_weights[y] = [
-                # some are not in the dict, 3000 samples may be not enough.
                 (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
             ]
 
@@ -226,38 +236,25 @@ class NormLIME(object):
 
         return g_weights
 
-    def preparation_normlime(self, data_path):
+    def preparation_normlime(self, data_):
         self._lime = LIME(
-            lambda images: self.predict_fn(images)[0],
+            self.predict_fn,
             self.num_samples,
             self.batch_size
         )
-        self._lime.preparation_lime(data_path)
+        self._lime.preparation_lime(data_)
 
-        image_show = read_image(data_path)
-        result = self.predict_fn(image_show)
+        image_show = read_image(data_)
 
-        logit = result[0][0]  # only one image here.
-        if abs(np.sum(logit) - 1.0) > 1e-4:
-            # softmax
-            exp_result = np.exp(logit)
-            probability = exp_result / np.sum(exp_result)
-        else:
-            probability = logit
-
-        # only explain top 1
-        pred_label = np.argsort(probability)
-        pred_label = pred_label[-1:]
-
-        self.predicted_label = pred_label[0]
-        self.predicted_probability = probability[pred_label[0]]
+        self.predicted_label = self._lime.predicted_label
+        self.predicted_probability = self._lime.predicted_probability
         self.image = image_show[0]
-        self.labels = pred_label
-        print('predicted result: ', pred_label[0], probability[pred_label[0]])
+        self.labels = self._lime.labels
+        # print(f'predicted result: {self.predicted_label} with probability {self.predicted_probability: .3f}')
+        print('performing NormLIME operations ...')
 
-        local_feature_map = result[1][0]
         cluster_labels = self.predict_cluster_labels(
-            local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments
+            compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_explainer.segments
         )
 
         g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
@@ -265,6 +262,10 @@ class NormLIME(object):
         return g_weights
 
     def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
+        if self.normlime_weights is None:
+            raise ValueError("Cannot find the correct precomputed NormLIME result. \n"
+                             "\tTry calling compute_normlime_weights() first or load the correct path.")
+
         g_weights = self.preparation_normlime(data_)
         lime_weights = self._lime.lime_explainer.local_exp
 
@@ -275,7 +276,8 @@ class NormLIME(object):
 
             psize = 5
             nrows = 4
-            weights_choices = [0.6, 0.85, 0.99]
+            weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85]
+            nums_to_show = []
             ncols = len(weights_choices)
 
             plt.close()
@@ -293,32 +295,31 @@ class NormLIME(object):
             # LIME visualization
             for i, w in enumerate(weights_choices):
                 num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+                nums_to_show.append(num_to_show)
                 temp, mask = self._lime.lime_explainer.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
+                axes[ncols + i].set_title(f"LIME: first {num_to_show} superpixels")
 
             # NormLIME visualization
             self._lime.lime_explainer.local_exp = g_weights
-            for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+            for i, num_to_show in enumerate(nums_to_show):
                 temp, mask = self._lime.lime_explainer.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels")
+                axes[ncols * 2 + i].set_title(f"NormLIME: first {num_to_show} superpixels")
 
             # NormLIME*LIME visualization
             combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
             self._lime.lime_explainer.local_exp = combined_weights
-            for i, w in enumerate(weights_choices):
-                num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
+            for i, num_to_show in enumerate(nums_to_show):
                 temp, mask = self._lime.lime_explainer.get_image_and_mask(
                     l, positive_only=False, hide_rest=False, num_features=num_to_show
                 )
                 axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
-                axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels")
+                axes[ncols * 3 + i].set_title(f"Combined: first {num_to_show} superpixels")
 
             self._lime.lime_explainer.local_exp = lime_weights
 
@@ -330,14 +331,6 @@ class NormLIME(object):
             plt.show()
 
 
-def load_kmeans_model(fname):
-    import pickle
-    with open(fname, 'rb') as f:
-        kmeans_model = pickle.load(f)
-
-    return kmeans_model
-
-
 def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
     segments = lime_explainer.segments
     lime_weights = lime_explainer.local_exp[label]
@@ -361,6 +354,9 @@ def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
             n = i + 1
             break
 
+    if percentage_to_show <= 0.0:
+        return 5
+
     if n == 0:
         return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
 
@@ -391,55 +387,6 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam
     return cam
 
 
-def avg_using_superpixels(features, segments):
-    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
-    for x in np.unique(segments):
-        one_list[x] = np.mean(features[segments == x], axis=0)
-
-    return one_list
-
-
-def centroid_using_superpixels(features, segments):
-    from skimage.measure import regionprops
-    regions = regionprops(segments + 1)
-    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
-    for i, r in enumerate(regions):
-        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
-    # print(one_list.shape)
-    return one_list
-
-
-def get_feature_for_kmeans(feature_map, segments):
-    from sklearn.preprocessing import normalize
-    centroid_feature = centroid_using_superpixels(feature_map, segments)
-    avg_feature = avg_using_superpixels(feature_map, segments)
-    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
-    x = normalize(x)
-    return x
-
-
-def combine_normlime_and_lime(lime_weights, g_weights):
-    pred_labels = lime_weights.keys()
-    combined_weights = {y: [] for y in pred_labels}
-
-    for y in pred_labels:
-        normlized_lime_weights_y = lime_weights[y]
-        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
-
-        normlized_g_weight_y = g_weights[y]
-        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
-
-        combined_weights[y] = [
-            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
-            for seg_k in lime_weights_dict.keys()
-        ]
-
-        combined_weights[y] = sorted(combined_weights[y],
-                                     key=lambda x: np.abs(x[1]), reverse=True)
-
-    return combined_weights
-
-
 def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
     import matplotlib.pyplot as plt
     if isinstance(data_, str):
@@ -469,4 +416,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
             os.path.join(
                 save_outdir, f_out
             )
-        )
+        )
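
With these changes, predict_using_normlime_weights simply looks up each superpixel's KMeans cluster in the precomputed table, now tolerating classes and clusters that are missing from it. A toy illustration of that lookup with made-up numbers:

    import numpy as np

    # class label -> {cluster id: precomputed NormLIME weight}
    normlime_weights = {7: {0: 0.12, 3: -0.05}}
    predicted_cluster_labels = [3, 0, 0, 5]  # one cluster id per superpixel

    cluster_weights_y = normlime_weights.get(7, {})
    g_weights = {7: sorted(
        [(i, cluster_weights_y.get(k, 0.0))
         for i, k in enumerate(predicted_cluster_labels)],
        key=lambda x: np.abs(x[1]), reverse=True)}

    print(g_weights)  # {7: [(1, 0.12), (2, 0.12), (0, -0.05), (3, 0.0)]}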

+ 221 - 0
paddlex/cv/models/explanation/core/normlime_base.py

@@ -0,0 +1,221 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import numpy as np
+import glob
+
+from ..as_data_reader.readers import read_image
+from . import lime_base
+from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans
+
+
+def load_kmeans_model(fname):
+    import pickle
+    with open(fname, 'rb') as f:
+        kmeans_model = pickle.load(f)
+
+    return kmeans_model
+
+
+def combine_normlime_and_lime(lime_weights, g_weights):
+    pred_labels = lime_weights.keys()
+    combined_weights = {y: [] for y in pred_labels}
+
+    for y in pred_labels:
+        normlized_lime_weights_y = lime_weights[y]
+        lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
+
+        normlized_g_weight_y = g_weights[y]
+        normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
+
+        combined_weights[y] = [
+            (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
+            for seg_k in lime_weights_dict.keys()
+        ]
+
+        combined_weights[y] = sorted(combined_weights[y],
+                                     key=lambda x: np.abs(x[1]), reverse=True)
+
+    return combined_weights
+
+
+def avg_using_superpixels(features, segments):
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for x in np.unique(segments):
+        one_list[x] = np.mean(features[segments == x], axis=0)
+
+    return one_list
+
+
+def centroid_using_superpixels(features, segments):
+    from skimage.measure import regionprops
+    regions = regionprops(segments + 1)
+    one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
+    for i, r in enumerate(regions):
+        one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
+    # print(one_list.shape)
+    return one_list
+
+
+def get_feature_for_kmeans(feature_map, segments):
+    from sklearn.preprocessing import normalize
+    centroid_feature = centroid_using_superpixels(feature_map, segments)
+    avg_feature = avg_using_superpixels(feature_map, segments)
+    x = np.concatenate((centroid_feature, avg_feature), axis=-1)
+    x = normalize(x)
+    return x
+
+
+def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_size=50, save_dir='./tmp'):
+    # save lime weights and kmeans cluster labels
+    precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir)
+
+    # load precomputed results, compute normlime weights and save.
+    fname_list = glob.glob(os.path.join(save_dir, f'lime_weights_s{num_samples}*.npy'))
+    return compute_normlime_weights(fname_list, save_dir, num_samples)
+
+
+def save_one_lime_predict_and_kmean_labels(lime_exp_all_weights, image_pred_labels, cluster_labels, save_path):
+
+    lime_weights = {}
+    for label in image_pred_labels:
+        lime_weights[label] = lime_exp_all_weights[label]
+
+    for_normlime_weights = {
+        'lime_weights': lime_weights,  # a dict: class_label: (seg_label, weight)
+        'cluster': cluster_labels  # a list with segments as indices.
+    }
+
+    np.save(save_path, for_normlime_weights)
+
+
+def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
+    kmeans_model = load_kmeans_model(h_pre_models_kmeans)
+
+    for data_index, each_data_ in enumerate(list_data_):
+        if isinstance(each_data_, str):
+            save_path = f"lime_weights_s{num_samples}_{each_data_.split('/')[-1].split('.')[0]}.npy"
+            save_path = os.path.join(save_dir, save_path)
+        else:
+            save_path = f"lime_weights_s{num_samples}_{data_index}.npy"
+            save_path = os.path.join(save_dir, save_path)
+
+        if os.path.exists(save_path):
+            print(f'{save_path} already exists, skipping it.')
+            continue
+
+        print('processing', each_data_ if isinstance(each_data_, str) else data_index,
+              f', {data_index}/{len(list_data_)}')
+
+        image_show = read_image(each_data_)
+        result = predict_fn(image_show)
+        result = result[0]  # only one image here.
+
+        if abs(np.sum(result) - 1.0) > 1e-4:
+            # softmax
+            exp_result = np.exp(result)
+            probability = exp_result / np.sum(exp_result)
+        else:
+            probability = result
+
+        pred_label = np.argsort(probability)[::-1]
+
+        # keep the highest-probability labels above the threshold, at most five of them
+        threshold = 0.05
+        top_k = 0
+        for l in pred_label:
+            if probability[l] < threshold or top_k == 5:
+                break
+            top_k += 1
+
+        if top_k == 0:
+            top_k = 1
+
+        pred_label = pred_label[:top_k]
+
+        algo = lime_base.LimeImageExplainer()
+        explainer = algo.explain_instance(image_show[0], predict_fn, pred_label, 0,
+                                          num_samples=num_samples, batch_size=batch_size)
+
+        cluster_labels = kmeans_model.predict(
+            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), explainer.segments)
+        )
+        save_one_lime_predict_and_kmean_labels(
+            explainer.local_exp, pred_label,
+            cluster_labels,
+            save_path
+        )
+
+
+def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples):
+    normlime_weights_all_labels = {}
+    for f in a_list_lime_fnames:
+        try:
+            lime_weights_and_cluster = np.load(f, allow_pickle=True).item()
+            lime_weights = lime_weights_and_cluster['lime_weights']
+            cluster = lime_weights_and_cluster['cluster']
+        except:
+            print('Failed to load precomputed LIME result, skipping', f)
+            continue
+        print('Loading precomputed LIME result,', f)
+
+        pred_labels = lime_weights.keys()
+        for y in pred_labels:
+            normlime_weights = normlime_weights_all_labels.get(y, {})
+            w_f_y = [abs(w[1]) for w in lime_weights[y]]
+            w_f_y_l1norm = sum(w_f_y)
+
+            for w in lime_weights[y]:
+                seg_label = w[0]
+                weight = w[1] * w[1] / w_f_y_l1norm
+                a = normlime_weights.get(cluster[seg_label], [])
+                a.append(weight)
+                normlime_weights[cluster[seg_label]] = a
+
+            normlime_weights_all_labels[y] = normlime_weights
+
+    # compute normlime
+    for y in normlime_weights_all_labels:
+        normlime_weights = normlime_weights_all_labels.get(y, {})
+        for k in normlime_weights:
+            normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k])
+
+    # check normlime
+    if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1:
+        print(
+            "\n"
+            "Warning: !!! \n"
+            f"There are at least {max(normlime_weights_all_labels.keys()) + 1} classes, "
+            f"but the NormLIME has results of only {len(normlime_weights_all_labels.keys())} classes. \n"
+            "This may cause unstable results in the later computation,"
+            " which can be improved by computing more test samples."
+            "\n"
+        )
+
+    n = 0
+    f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+    while os.path.exists(
+            os.path.join(save_dir, f_out)
+    ):
+        n += 1
+        f_out = f'normlime_weights_s{lime_num_samples}_samples_{len(a_list_lime_fnames)}-{n}.npy'
+        continue
+
+    np.save(
+        os.path.join(save_dir, f_out),
+        normlime_weights_all_labels
+    )
+    return os.path.join(save_dir, f_out)
+
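
The aggregation in compute_normlime_weights normalizes each image's LIME weights for class y by their L1 norm, squares them, buckets them by the KMeans cluster of the corresponding superpixel, and finally averages each bucket over all precomputed images. A toy numeric check of one image's contribution (made-up weights):

    # (segment, weight) pairs for one class, and segment -> cluster assignment
    lime_weights_y = [(0, 0.6), (1, -0.2), (2, 0.2)]
    cluster = [5, 5, 9]
    l1_norm = sum(abs(w) for _, w in lime_weights_y)  # 1.0

    buckets = {}
    for seg, w in lime_weights_y:
        buckets.setdefault(cluster[seg], []).append(w * w / l1_norm)

    normlime = {c: sum(v) / len(v) for c, v in buckets.items()}
    print(normlime)  # roughly {5: 0.2, 9: 0.04}; cluster 5 averages 0.36 and 0.04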

BIN
paddlex/cv/models/explanation/pre_models/bnv1_1_mean


BIN
paddlex/cv/models/explanation/pre_models/bnv1_1_offset


BIN
paddlex/cv/models/explanation/pre_models/bnv1_1_scale


BIN
paddlex/cv/models/explanation/pre_models/bnv1_1_variance


BIN
paddlex/cv/models/explanation/pre_models/bnv1_2_mean


BIN
paddlex/cv/models/explanation/pre_models/bnv1_2_offset


BIN
paddlex/cv/models/explanation/pre_models/bnv1_2_scale


BIN
paddlex/cv/models/explanation/pre_models/bnv1_2_variance


BIN
paddlex/cv/models/explanation/pre_models/bnv1_3_mean


BIN
paddlex/cv/models/explanation/pre_models/bnv1_3_offset


BIN
paddlex/cv/models/explanation/pre_models/bnv1_3_scale


BIN
paddlex/cv/models/explanation/pre_models/bnv1_3_variance


BIN
paddlex/cv/models/explanation/pre_models/conv1_1_weights


BIN
paddlex/cv/models/explanation/pre_models/conv1_2_weights


BIN
paddlex/cv/models/explanation/pre_models/conv1_3_weights


BIN
paddlex/cv/models/explanation/pre_models/kmeans_model.pkl


BIN
paddlex/cv/models/explanation/pre_models/normlime_weights_imagenet_resnet50vc.npy


+ 57 - 0
paddlex/cv/models/explanation/visualize.py

@@ -18,10 +18,12 @@ import copy
 import os.path as osp
 import numpy as np
 from .core.explanation import Explanation
+from .core.normlime_base import precompute_normlime_weights
 
 
 def visualize(img_file, 
               model, 
+              normlime_dataset=None,
               explanation_type='lime',
               num_samples=3000, 
               batch_size=50,
@@ -36,6 +38,12 @@ def visualize(img_file,
     explaier = None
     if explanation_type == 'lime':
         explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
+    elif explanation_type == 'normlime':
+        if normlime_dataset is None:
+            raise Exception('The normlime_dataset is None. Cannot perform the NormLIME explanation without it.')
+        explaier = get_normlime_explaier(img, model, normlime_dataset, 
+                                     num_samples=num_samples, batch_size=batch_size,
+                                     save_dir=save_dir)
     else:
        raise Exception('The {} explanation method is not supported yet!'.format(explanation_type))
     img_name = osp.splitext(osp.split(img_file)[-1])[0]
@@ -45,6 +53,8 @@ def visualize(img_file,
 def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
     def predict_func(image):
         image = image.astype('float32')
+        for i in range(image.shape[0]):
+            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
         model.test_transforms.transforms = model.test_transforms.transforms[-2:]
         out = model.explanation_predict(image)
         return out[0]
@@ -53,4 +63,51 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
                             num_samples=num_samples, 
                             batch_size=batch_size)
     return explaier
+
+
+def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
+    def precompute_predict_func(image):
+        image = image.astype('float32')
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = model.explanation_predict(image)
+        return out[0]
+    def predict_func(image):
+        image = image.astype('float32')
+        for i in range(image.shape[0]):
+            image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
+        model.test_transforms.transforms = model.test_transforms.transforms[-2:]
+        out = model.explanation_predict(image)
+        return out[0]
+    root_path = os.environ['HOME']
+    root_path = osp.join(root_path, '.paddlex')
+    pre_models_path = osp.join(root_path, "pre_models")
+    if not osp.exists(pre_models_path):
+        os.makedirs(pre_models_path)
+        # TODO
+        # paddlex.utils.download_and_decompress(url, path=pre_models_path)
+    npy_dir = precompute_for_normlime(precompute_predict_func, 
+                                      normlime_dataset, 
+                                      num_samples=num_samples, 
+                                      batch_size=batch_size,
+                                      save_dir=save_dir)
+    explaier = Explanation('normlime', 
+                            predict_func,
+                            num_samples=num_samples, 
+                            batch_size=batch_size,
+                            normlime_weights=npy_dir)
+    return explaier
+
+
+def precompute_for_normlime(predict_func, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
+    image_list = []
+    for item in normlime_dataset.file_list:
+        image_list.append(item[0])
+    return precompute_normlime_weights(
+            image_list,  
+            predict_func,
+            num_samples=num_samples, 
+            batch_size=batch_size,
+            save_dir=save_dir)
+
+
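
Putting the pieces together, the intended call path for a NormLIME visualization looks roughly like the sketch below. This is a hedged example: the model/dataset setup follows the usual PaddleX classification workflow, the paths are placeholders, and note that visualize() first precomputes LIME weights over the whole normlime_dataset, which is expensive:

    import paddlex as pdx
    from paddlex.cv.models.explanation.visualize import visualize

    model = pdx.load_model('output/mobilenetv2/best_model')  # placeholder path
    normlime_dataset = pdx.datasets.ImageNet(                # any dataset with a file_list
        data_dir='dataset',
        file_list='dataset/val_list.txt',
        label_list='dataset/labels.txt',
        transforms=model.test_transforms)

    visualize('test.jpg', model,
              normlime_dataset=normlime_dataset,
              explanation_type='normlime',
              num_samples=3000, batch_size=50,
              save_dir='./')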

BIN
paddlex/cv/nets/__pycache__/__init__.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/backbone_utils.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/darknet.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/densenet.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/mobilenet_v1.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/mobilenet_v2.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/mobilenet_v3.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/resnet.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/shufflenet_v2.cpython-37.pyc


BIN
paddlex/cv/nets/__pycache__/xception.cpython-37.pyc


+ 2 - 5
paddlex/cv/nets/darknet.py

@@ -18,7 +18,6 @@ from __future__ import print_function
 
 import six
 import math
-from collections import OrderedDict
 
 from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
@@ -137,10 +136,8 @@ class DarkNet(object):
     def __call__(self, input):
         """
         Get the backbone of DarkNet, that is output for the 5 stages.
-
         Args:
             input (Variable): input variable.
-
         Returns:
             The last variables of each stage.
         """
@@ -183,6 +180,6 @@ class DarkNet(object):
                     initializer=fluid.initializer.Uniform(-stdv, stdv),
                     name='fc_weights'),
                 bias_attr=ParamAttr(name='fc_offset'))
-            return OrderedDict([('logits', out)])
+            return out
 
-        return blocks
+        return blocks

+ 2 - 3
paddlex/cv/nets/densenet.py

@@ -15,7 +15,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import OrderedDict
 import paddle
 import paddle.fluid as fluid
 import math
@@ -97,7 +96,7 @@ class DenseNet(object):
                     initializer=fluid.initializer.Uniform(-stdv, stdv),
                     name="fc_weights"),
                 bias_attr=ParamAttr(name='fc_offset'))
-            return OrderedDict([('logits', out)])
+            return out
 
     def make_transition(self, input, num_output_features, name=None):
         bn_ac = fluid.layers.batch_norm(
@@ -174,4 +173,4 @@ class DenseNet(object):
             bn_ac_conv = fluid.layers.dropout(
                 x=bn_ac_conv, dropout_prob=dropout)
         bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1)
-        return bn_ac_conv
+        return bn_ac_conv

BIN
paddlex/cv/nets/detection/__pycache__/__init__.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/bbox_head.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/faster_rcnn.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/fpn.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/mask_head.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/mask_rcnn.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/roi_extractor.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/rpn_head.cpython-37.pyc


BIN
paddlex/cv/nets/detection/__pycache__/yolo_v3.cpython-37.pyc


+ 2 - 4
paddlex/cv/nets/mobilenet_v1.py

@@ -16,7 +16,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import OrderedDict
 from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
@@ -25,7 +24,6 @@ from paddle.fluid.regularizer import L2Decay
 class MobileNetV1(object):
     """
     MobileNet v1, see https://arxiv.org/abs/1704.04861
-
     Args:
         norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
         norm_decay (float): weight decay for normalization layer weights
@@ -197,7 +195,7 @@ class MobileNetV1(object):
                 param_attr=ParamAttr(
                     initializer=fluid.initializer.MSRA(), name="fc7_weights"),
                 bias_attr=ParamAttr(name="fc7_offset"))
-            return OrderedDict([('logits', out)])
+            return out
 
         if not self.with_extra_blocks:
             return blocks
@@ -215,4 +213,4 @@ class MobileNetV1(object):
         module17 = self._extra_block(module16, num_filters[3][0],
                                      num_filters[3][1], 1, 2,
                                      self.prefix_name + "conv7_4")
-        return module11, module13, module14, module15, module16, module17
+        return module11, module13, module14, module15, module16, module17

+ 1 - 3
paddlex/cv/nets/mobilenet_v2.py

@@ -14,7 +14,6 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from collections import OrderedDict
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 
@@ -110,7 +109,6 @@ class MobileNetV2:
                 size=self.num_classes,
                 param_attr=ParamAttr(name='fc10_weights'),
                 bias_attr=ParamAttr(name='fc10_offset'))
-            return OrderedDict([('logits', output)])
         return output
 
     def modify_bottle_params(self, output_stride=None):
@@ -241,4 +239,4 @@ class MobileNetV2:
                 padding=1,
                 expansion_factor=t,
                 name=name + '_' + str(i + 1))
-        return last_residual_block, depthwise_output
+        return last_residual_block, depthwise_output

+ 1 - 2
paddlex/cv/nets/mobilenet_v3.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import OrderedDict
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
@@ -328,7 +327,7 @@ class MobileNetV3():
                                   size=self.num_classes,
                                   param_attr=ParamAttr(name='fc_weights'),
                                   bias_attr=ParamAttr(name='fc_offset'))            
-            return OrderedDict([('logits', out)])
+            return out
 
         if not self.with_extra_blocks:
             return blocks

+ 1 - 1
paddlex/cv/nets/resnet.py

@@ -474,7 +474,7 @@ class ResNet(object):
                 size=self.num_classes,
                 param_attr=fluid.param_attr.ParamAttr(
                     initializer=fluid.initializer.Uniform(-stdv, stdv)))
-            return OrderedDict([('logits', out)])
+            return out
 
         return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
                             for idx, feat in enumerate(res_endpoints)])

BIN
paddlex/cv/nets/segmentation/__pycache__/__init__.cpython-37.pyc


BIN
paddlex/cv/nets/segmentation/__pycache__/deeplabv3p.cpython-37.pyc


BIN
paddlex/cv/nets/segmentation/__pycache__/unet.cpython-37.pyc


BIN
paddlex/cv/nets/segmentation/model_utils/__pycache__/__init__.cpython-37.pyc


BIN
paddlex/cv/nets/segmentation/model_utils/__pycache__/libs.cpython-37.pyc


BIN
paddlex/cv/nets/segmentation/model_utils/__pycache__/loss.cpython-37.pyc


+ 1 - 3
paddlex/cv/nets/shufflenet_v2.py

@@ -15,7 +15,6 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from collections import OrderedDict
 import paddle.fluid as fluid
 from paddle.fluid.initializer import MSRA
 from paddle.fluid.param_attr import ParamAttr
@@ -102,7 +101,6 @@ class ShuffleNetV2():
                 size=self.num_classes,
                 param_attr=ParamAttr(initializer=MSRA(), name='fc6_weights'),
                 bias_attr=ParamAttr(name='fc6_offset'))
-            return OrderedDict([('logits', output)])
         return output
 
     def conv_bn_layer(self,
@@ -271,4 +269,4 @@ class ShuffleNetV2():
                 name='stage_' + name + '_conv3')
             out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1)
 
-        return self.channel_shuffle(out, 2)
+        return self.channel_shuffle(out, 2)

+ 2 - 3
paddlex/cv/nets/xception.py

@@ -19,7 +19,6 @@ from __future__ import print_function
 import contextlib
 import paddle
 import math
-from collections import OrderedDict
 import paddle.fluid as fluid
 from .segmentation.model_utils.libs import scope, name_scope
 from .segmentation.model_utils.libs import bn, bn_relu, relu
@@ -105,7 +104,7 @@ class Xception():
                         initializer=fluid.initializer.Uniform(-stdv, stdv)),
                     bias_attr=fluid.param_attr.ParamAttr(name='bias'))
 
-            return OrderedDict([('logits', out)])
+            return out
         else:
             return data
 
@@ -330,4 +329,4 @@ def xception_41(num_classes=None):
 
 def xception_71(num_classes=None):
     model = Xception(num_classes, 71)
-    return model
+    return model