Pārlūkot izejas kodu

for sklrearn 0.23

sunyanfang01 5 gadi atpakaļ
vecāks
revīzija
3aff3f2fcf
1 mainītis faili ar 6 papildinājumiem un 3 dzēšanām
  1. 6 3
      paddlex/interpret/core/normlime_base.py

+ 6 - 3
paddlex/interpret/core/normlime_base.py

@@ -150,9 +150,12 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
         interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
                                           num_samples=num_samples, batch_size=batch_size)
 
-        cluster_labels = kmeans_model.predict(
-            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
-        )
+        X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
+        try:
+            cluster_labels = kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_)
         save_one_lime_predict_and_kmean_labels(
             interpreter.local_weights, pred_label,
             cluster_labels,