|
|
@@ -34,8 +34,12 @@ def sensitivity(program,
|
|
|
param_names,
|
|
|
eval_func,
|
|
|
sensitivities_file=None,
|
|
|
- pruned_ratios=None):
|
|
|
- scope = fluid.global_scope()
|
|
|
+ pruned_ratios=None,
|
|
|
+ scope=None):
|
|
|
+ if scope is None:
|
|
|
+ scope = fluid.global_scope()
|
|
|
+ else:
|
|
|
+ scope = scope
|
|
|
graph = GraphWrapper(program)
|
|
|
sensitivities = load_sensitivities(sensitivities_file)
|
|
|
|
|
|
@@ -256,7 +260,8 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
|
|
|
prune_names,
|
|
|
eval_for_prune,
|
|
|
sensitivities_file=save_file,
|
|
|
- pruned_ratios=list(np.arange(0.1, 1, 0.1)))
|
|
|
+ pruned_ratios=list(np.arange(0.1, 1, 0.1)),
|
|
|
+ scope=model.scope)
|
|
|
return sensitivitives
|
|
|
|
|
|
|