@@ -116,6 +116,21 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
     Returns:
         paddle.fluid.Program: The pruned Program.
     """
+    prog_var_shape_dict = {}
+    for var in program.list_vars():
+        try:
+            prog_var_shape_dict[var.name] = var.shape
+        except Exception:
+            pass
+    index = 0
+    for param, ratio in zip(prune_names, prune_ratios):
+        origin_num = prog_var_shape_dict[param][0]
+        pruned_num = int(round(origin_num * ratio))
+        while origin_num == pruned_num:
+            ratio -= 0.1
+            pruned_num = int(round(origin_num * ratio))
+        prune_ratios[index] = ratio
+        index += 1
     scope = fluid.global_scope()
     pruner = Pruner()
     program, _, _ = pruner.prune(
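
Both this hunk and the `cal_model_size` hunk below add the same safeguard: when `int(round(origin_num * ratio))` rounds up to the full output-channel count, the whole parameter would be pruned away, so the ratio is stepped down by 0.1 until at least one channel survives. A minimal sketch of that back-off in isolation (the helper name and sample numbers are illustrative, not part of the patch):

```python
# Illustrative sketch of the ratio back-off added by this patch; the helper
# name and the example values are hypothetical.
def adjust_prune_ratio(origin_num, ratio, step=0.1):
    """Lower `ratio` until at least one channel of the layer survives."""
    pruned_num = int(round(origin_num * ratio))
    while origin_num == pruned_num:  # rounding would remove every channel
        ratio -= step
        pruned_num = int(round(origin_num * ratio))
    return ratio

# With 8 output channels and ratio 0.97, int(round(8 * 0.97)) == 8, i.e. the
# whole layer would be removed; one back-off step leaves 7 of 8 channels pruned.
assert adjust_prune_ratio(8, 0.97) < 0.97
assert adjust_prune_ratio(64, 0.5) == 0.5  # safe ratios are left untouched
```
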
@@ -266,6 +281,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
     params_ratios = paddleslim.prune.get_ratios_by_loss(
         sensitivitives, eval_metric_loss)
+
     return params_ratios
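
`get_params_ratios` itself is unchanged apart from the added blank line: it loads a PaddleSlim sensitivities file and turns it into per-parameter prune ratios for a given acceptable metric drop. A hedged usage sketch, assuming the function is importable from the patched module (the file path and printed values are placeholders):

```python
# Hypothetical call to get_params_ratios; './model.sensitivities' is a
# placeholder path and the printed mapping is illustrative only.
ratios = get_params_ratios('./model.sensitivities', eval_metric_loss=0.05)
print(ratios)  # e.g. {'conv1_weights': 0.2, 'conv2_weights': 0.4, ...}
```
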
@@ -284,6 +300,19 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
     """
     prune_params_ratios = get_params_ratios(sensitivities_file,
                                             eval_metric_loss)
+    prog_var_shape_dict = {}
+    for var in program.list_vars():
+        try:
+            prog_var_shape_dict[var.name] = var.shape
+        except Exception:
+            pass
+    for param, ratio in prune_params_ratios.items():
+        origin_num = prog_var_shape_dict[param][0]
+        pruned_num = int(round(origin_num * ratio))
+        while origin_num == pruned_num:
+            ratio -= 0.1
+            pruned_num = int(round(origin_num * ratio))
+        prune_params_ratios[param] = ratio
     prune_program = channel_prune(
         program,
         list(prune_params_ratios.keys()),
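
Downstream, `cal_model_size` hands the adjusted ratios to `channel_prune`, whose signature appears in the first hunk header. A hedged sketch of that composition, assuming `channel_prune` is importable from the patched module (`infer_prog` and the ratio values are placeholders, not from the patch):

```python
# Hypothetical composition of the helpers touched by this patch; `infer_prog`
# stands in for a real paddle.fluid.Program and the ratios are illustrative.
import paddle.fluid as fluid

place = fluid.CPUPlace()
ratio_dict = {'conv1_weights': 0.2, 'conv2_weights': 0.4}  # illustrative ratios
pruned_prog = channel_prune(
    infer_prog,                  # Program whose conv channels get pruned
    list(ratio_dict.keys()),     # parameter names to prune
    list(ratio_dict.values()),   # per-parameter prune ratios
    place)
```
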