Browse Source

fix scope bug for prune

jiangjiajun 5 năm trước cách đây
mục cha
commit
7f9101513e
1 tập tin đã thay đổi với 8 bổ sung3 xóa
  1. 8 3
      paddlex/cv/models/slim/prune.py

+ 8 - 3
paddlex/cv/models/slim/prune.py

@@ -34,8 +34,12 @@ def sensitivity(program,
                 param_names,
                 eval_func,
                 sensitivities_file=None,
-                pruned_ratios=None):
-    scope = fluid.global_scope()
+                pruned_ratios=None,
+                scope=None):
+    if scope is None:
+        scope = fluid.global_scope()
+    else:
+        scope = scope
     graph = GraphWrapper(program)
     sensitivities = load_sensitivities(sensitivities_file)
 
@@ -256,7 +260,8 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
         prune_names,
         eval_for_prune,
         sensitivities_file=save_file,
-        pruned_ratios=list(np.arange(0.1, 1, 0.1)))
+        pruned_ratios=list(np.arange(0.1, 1, 0.1)),
+        scope=model.scope)
     return sensitivitives