@@ -21,9 +21,6 @@ import os.path as osp
 from functools import reduce
 import paddle.fluid as fluid
 from multiprocessing import Process, Queue
-import paddleslim
-from paddleslim.prune import Pruner, load_sensitivities
-from paddleslim.core import GraphWrapper
 from .prune_config import get_prune_params
 import paddlex.utils.logging as logging
 from paddlex.utils import seconds_to_hms
@@ -36,6 +33,10 @@ def sensitivity(program,
                 sensitivities_file=None,
                 pruned_ratios=None,
                 scope=None):
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     if scope is None:
         scope = fluid.global_scope()
     else:
@@ -104,7 +105,12 @@ def sensitivity(program,
     return sensitivities
 
 
-def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
+def channel_prune(program,
+                  prune_names,
+                  prune_ratios,
+                  place,
+                  only_graph=False,
+                  scope=None):
     """Channel pruning.
 
     Args:
@@ -119,6 +125,10 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, s
     Returns:
         paddle.fluid.Program: The pruned Program.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     prog_var_shape_dict = {}
     for var in program.list_vars():
         try:
@@ -163,6 +173,10 @@ def prune_program(model, prune_params_ratios=None):
         prune_params_ratios (dict): A dict mapping the names of the parameters to prune to their pruning ratios; when None,
             the default parameter names and pruning ratios are used. Defaults to None.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     assert model.status == 'Normal', 'Only the models saved while training are supported!'
     place = model.places[0]
     train_prog = model.train_prog
@@ -175,10 +189,15 @@ def prune_program(model, prune_params_ratios=None):
     prune_ratios = [
         prune_params_ratios[prune_name] for prune_name in prune_names
     ]
-    model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
-                                     place, scope=model.scope)
+    model.train_prog = channel_prune(
+        train_prog, prune_names, prune_ratios, place, scope=model.scope)
     model.test_prog = channel_prune(
-        eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
+        eval_prog,
+        prune_names,
+        prune_ratios,
+        place,
+        only_graph=True,
+        scope=model.scope)
 
 
 def update_program(program, model_dir, place, scope=None):
@@ -193,6 +212,10 @@ def update_program(program, model_dir, place, scope=None):
     Returns:
         paddle.fluid.Program: The updated Program.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     graph = GraphWrapper(program)
     with open(osp.join(model_dir, "prune.yml")) as f:
         shapes = yaml.load(f.read(), Loader=yaml.Loader)
@@ -203,11 +226,9 @@
     for block in program.blocks:
         for param in block.all_parameters():
             if param.name in shapes:
-                param_tensor = scope.find_var(
-                    param.name).get_tensor()
+                param_tensor = scope.find_var(param.name).get_tensor()
                 param_tensor.set(
-                    np.zeros(list(shapes[param.name])).astype('float32'),
-                    place)
+                    np.zeros(list(shapes[param.name])).astype('float32'), place)
     graph.update_groups_of_conv()
     graph.infer_shape()
     return program
@@ -243,6 +264,10 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
 
     where ``weight_0`` is the name of a convolution kernel, and ``sensitivities['weight_0']`` is a dict whose keys are pruning ratios and whose values are sensitivities.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     assert model.status == 'Normal', 'Only the models saved while training are supported!'
     if os.path.exists(save_file):
         os.remove(save_file)
@@ -268,6 +293,11 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
     return sensitivitives
 
 
+def analysis(model, dataset, batch_size=8, save_file='./model.sensi.data'):
+    return cal_params_sensitivities(
+        model, eval_dataset=dataset, batch_size=batch_size, save_file=save_file)
+
+
 def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     """Given the configured tolerable accuracy loss metric_loss_thresh and the saved parameter-sensitivity file sensetive_file,
     obtain the pruning configuration for the parameters.
@@ -288,6 +318,10 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
 
     where each key is the name of a convolution kernel and each value is its pruning ratio.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     if not osp.exists(sensitivities_file):
         raise Exception('The sensitivities file is not exists!')
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
@@ -296,7 +330,11 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     return params_ratios
 
 
-def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
+def cal_model_size(program,
+                   place,
+                   sensitivities_file,
+                   eval_metric_loss=0.05,
+                   scope=None):
     """Computes, under the tolerable accuracy loss, the size of the pruned model relative to the current model size.
 
     Args:
@@ -309,6 +347,10 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, sc
     Returns:
         float: The ratio of the pruned model size to the current model size.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     prune_params_ratios = get_params_ratios(sensitivities_file,
                                             eval_metric_loss)
     prog_var_shape_dict = {}
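
The change applied in every hunk above is the same deferred (lazy) import: paddleslim, Pruner, load_sensitivities, and GraphWrapper move from module level into the bodies of the functions that use them, so this module can be imported even when paddleslim is not installed. A minimal sketch of the pattern, assuming a hypothetical prune_weights helper and error message that are not part of this patch:

    def prune_weights(program, ratios):
        # Deferred import: paddleslim is only required when pruning is
        # actually requested, not when this module is first imported.
        try:
            import paddleslim
        except ImportError:
            # Hypothetical guard; the patch itself adds no such check.
            raise ImportError("paddleslim is required for pruning; "
                              "install it with `pip install paddleslim`.")
        # ... build a paddleslim.prune.Pruner and prune `program` here ...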
|