|
|
@@ -104,7 +104,7 @@ def sensitivity(program,
|
|
|
return sensitivities
|
|
|
|
|
|
|
|
|
-def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
|
|
|
+def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
|
|
|
"""通道裁剪。
|
|
|
|
|
|
Args:
|
|
|
@@ -134,7 +134,8 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
|
|
|
pruned_num = int(round(origin_num * (ratio)))
|
|
|
prune_ratios[index] = ratio
|
|
|
index += 1
|
|
|
- scope = fluid.global_scope()
|
|
|
+ if scope is None:
|
|
|
+ scope = fluid.global_scope()
|
|
|
pruner = Pruner()
|
|
|
program, _, _ = pruner.prune(
|
|
|
program,
|
|
|
@@ -175,12 +176,12 @@ def prune_program(model, prune_params_ratios=None):
|
|
|
prune_params_ratios[prune_name] for prune_name in prune_names
|
|
|
]
|
|
|
model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
|
|
|
- place)
|
|
|
+ place, scope=model.scope)
|
|
|
model.test_prog = channel_prune(
|
|
|
- eval_prog, prune_names, prune_ratios, place, only_graph=True)
|
|
|
+ eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
|
|
|
|
|
|
|
|
|
-def update_program(program, model_dir, place):
|
|
|
+def update_program(program, model_dir, place, scope=None):
|
|
|
"""根据裁剪信息更新Program和参数。
|
|
|
|
|
|
Args:
|
|
|
@@ -197,10 +198,12 @@ def update_program(program, model_dir, place):
|
|
|
shapes = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
for param, shape in shapes.items():
|
|
|
graph.var(param).set_shape(shape)
|
|
|
+ if scope is None:
|
|
|
+ scope = fluid.global_scope()
|
|
|
for block in program.blocks:
|
|
|
for param in block.all_parameters():
|
|
|
if param.name in shapes:
|
|
|
- param_tensor = fluid.global_scope().find_var(
|
|
|
+ param_tensor = scope.find_var(
|
|
|
param.name).get_tensor()
|
|
|
param_tensor.set(
|
|
|
np.zeros(list(shapes[param.name])).astype('float32'),
|
|
|
@@ -293,7 +296,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
|
|
|
return params_ratios
|
|
|
|
|
|
|
|
|
-def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
|
|
|
+def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
|
|
|
"""在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
|
|
|
|
|
|
Args:
|
|
|
@@ -326,7 +329,8 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
|
|
|
list(prune_params_ratios.keys()),
|
|
|
list(prune_params_ratios.values()),
|
|
|
place,
|
|
|
- only_graph=True)
|
|
|
+ only_graph=True,
|
|
|
+ scope=scope)
|
|
|
origin_size = 0
|
|
|
new_size = 0
|
|
|
for var in program.list_vars():
|