Explorar o código

fix the prune

sunyanfang01 %!s(int64=5) %!d(string=hai) anos
pai
achega
570fdf6cda
Modificáronse 1 ficheiros con 29 adicións e 0 borrados
  1. 29 0
      paddlex/cv/models/slim/prune.py

+ 29 - 0
paddlex/cv/models/slim/prune.py

@@ -116,6 +116,21 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
     Returns:
         paddle.fluid.Program: 裁剪后的Program。
     """
+    prog_var_shape_dict = {}
+    for var in program.list_vars():
+        try:
+            prog_var_shape_dict[var.name] = var.shape
+        except Exception:
+            pass
+    index = 0
+    for param, ratio in zip(prune_names, prune_ratios):
+        origin_num = prog_var_shape_dict[param][0]
+        pruned_num = int(round(origin_num * ratio))
+        while origin_num == pruned_num:
+            ratio -= 0.1
+            pruned_num = int(round(origin_num * (ratio)))
+            prune_ratios[index] = ratio
+        index += 1
     scope = fluid.global_scope()
     pruner = Pruner()
     program, _, _ = pruner.prune(
@@ -266,6 +281,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
     params_ratios = paddleslim.prune.get_ratios_by_loss(
         sensitivitives, eval_metric_loss)
+    
     return params_ratios
 
 
@@ -284,6 +300,19 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
     """
     prune_params_ratios = get_params_ratios(sensitivities_file,
                                             eval_metric_loss)
+    prog_var_shape_dict = {}
+    for var in program.list_vars():
+        try:
+            prog_var_shape_dict[var.name] = var.shape
+        except Exception:
+            pass
+    for param, ratio in prune_params_ratios.items():
+        origin_num = prog_var_shape_dict[param][0]
+        pruned_num = int(round(origin_num * ratio))
+        while origin_num == pruned_num:
+            ratio -= 0.1
+            pruned_num = int(round(origin_num * (ratio)))
+            prune_params_ratios[param] = ratio
     prune_program = channel_prune(
         program,
         list(prune_params_ratios.keys()),