浏览代码

modify cal sensitivities

jiangjiajun 5 年之前
父节点
当前提交
b66b3a1a41
共有 1 个文件被更改,包括 12 次插入10 次删除
  1. 12 10
      paddlex/cv/models/slim/prune.py

+ 12 - 10
paddlex/cv/models/slim/prune.py

@@ -66,16 +66,15 @@ def sensitivity(program,
             progress = "%.2f%%" % (progress * 100)
             logging.info(
                 "Total evaluate iters={}, current={}, progress={}, eta={}".
-                format(
-                    total_evaluate_iters, current_iter, progress,
-                    seconds_to_hms(
-                        int(cost * (total_evaluate_iters - current_iter)))),
+                format(total_evaluate_iters, current_iter, progress,
+                       seconds_to_hms(
+                           int(cost * (total_evaluate_iters - current_iter)))),
                 use_color=True)
             current_iter += 1
 
             pruner = Pruner()
-            logging.info("sensitive - param: {}; ratios: {}".format(
-                name, ratio))
+            logging.info("sensitive - param: {}; ratios: {}".format(name,
+                                                                    ratio))
             pruned_program, param_backup, _ = pruner.prune(
                 program=graph.program,
                 scope=scope,
@@ -87,8 +86,8 @@ def sensitivity(program,
                 param_backup=True)
             pruned_metric = eval_func(pruned_program)
             loss = (baseline - pruned_metric) / baseline
-            logging.info("pruned param: {}; {}; loss={}".format(
-                name, ratio, loss))
+            logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
+                                                                loss))
 
             sensitivities[name][ratio] = loss
 
@@ -221,6 +220,9 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
 
             其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
     """
+    if os.path.exists(save_file):
+        os.remove(save_file)
+
     prune_names = get_prune_params(model)
 
     def eval_for_prune(program):
@@ -264,8 +266,8 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     if not osp.exists(sensitivities_file):
         raise Exception('The sensitivities file is not exists!')
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
-    params_ratios = paddleslim.prune.get_ratios_by_loss(
-        sensitivitives, eval_metric_loss)
+    params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
+                                                        eval_metric_loss)
     return params_ratios