|
@@ -158,7 +158,7 @@ def prune_program(model, prune_params_ratios=None):
|
|
|
prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
|
|
prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
|
|
|
使用默认裁剪参数名和裁剪率。默认为None。
|
|
使用默认裁剪参数名和裁剪率。默认为None。
|
|
|
"""
|
|
"""
|
|
|
- assert model.status == 'Normal', 'Only the model after training can be pruned!'
|
|
|
|
|
|
|
+ assert model.status == 'Normal', 'Only the models saved while training are supported!'
|
|
|
place = model.places[0]
|
|
place = model.places[0]
|
|
|
train_prog = model.train_prog
|
|
train_prog = model.train_prog
|
|
|
eval_prog = model.test_prog
|
|
eval_prog = model.test_prog
|
|
@@ -236,7 +236,7 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
|
|
|
|
|
|
|
|
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
|
|
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
|
|
|
"""
|
|
"""
|
|
|
- assert model.status == 'Normal', 'Only the model after training can calculate sensitivities data!'
|
|
|
|
|
|
|
+ assert model.status == 'Normal', 'Only the models saved while training are supported!'
|
|
|
if os.path.exists(save_file):
|
|
if os.path.exists(save_file):
|
|
|
os.remove(save_file)
|
|
os.remove(save_file)
|
|
|
|
|
|