|
|
@@ -66,16 +66,15 @@ def sensitivity(program,
|
|
|
progress = "%.2f%%" % (progress * 100)
|
|
|
logging.info(
|
|
|
"Total evaluate iters={}, current={}, progress={}, eta={}".
|
|
|
- format(
|
|
|
- total_evaluate_iters, current_iter, progress,
|
|
|
- seconds_to_hms(
|
|
|
- int(cost * (total_evaluate_iters - current_iter)))),
|
|
|
+ format(total_evaluate_iters, current_iter, progress,
|
|
|
+ seconds_to_hms(
|
|
|
+ int(cost * (total_evaluate_iters - current_iter)))),
|
|
|
use_color=True)
|
|
|
current_iter += 1
|
|
|
|
|
|
pruner = Pruner()
|
|
|
- logging.info("sensitive - param: {}; ratios: {}".format(
|
|
|
- name, ratio))
|
|
|
+ logging.info("sensitive - param: {}; ratios: {}".format(name,
|
|
|
+ ratio))
|
|
|
pruned_program, param_backup, _ = pruner.prune(
|
|
|
program=graph.program,
|
|
|
scope=scope,
|
|
|
@@ -87,8 +86,8 @@ def sensitivity(program,
|
|
|
param_backup=True)
|
|
|
pruned_metric = eval_func(pruned_program)
|
|
|
loss = (baseline - pruned_metric) / baseline
|
|
|
- logging.info("pruned param: {}; {}; loss={}".format(
|
|
|
- name, ratio, loss))
|
|
|
+ logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
|
|
|
+ loss))
|
|
|
|
|
|
sensitivities[name][ratio] = loss
|
|
|
|
|
|
@@ -221,6 +220,9 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
|
|
|
|
|
|
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
|
|
|
"""
|
|
|
+ if os.path.exists(save_file):
|
|
|
+ os.remove(save_file)
|
|
|
+
|
|
|
prune_names = get_prune_params(model)
|
|
|
|
|
|
def eval_for_prune(program):
|
|
|
@@ -264,8 +266,8 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
|
|
|
if not osp.exists(sensitivities_file):
|
|
|
raise Exception('The sensitivities file is not exists!')
|
|
|
sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
|
|
|
- params_ratios = paddleslim.prune.get_ratios_by_loss(
|
|
|
- sensitivitives, eval_metric_loss)
|
|
|
+ params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
|
|
|
+ eval_metric_loss)
|
|
|
return params_ratios
|
|
|
|
|
|
|