|
|
4 жил өмнө | |
|---|---|---|
| .. | ||
| README.md | 4 жил өмнө | |
| mobilenetv2_prune.py | 4 жил өмнө | |
| mobilenetv2_train.py | 4 жил өмнө | |
python mobilenetv2_train.py
在此步骤中,训练的模型会保存在output/mobilenet_v2目录下
python mobilenetv2_prune.py
mobilenetv2_prune.py中主要执行了以下API:
step 1: 分析模型各层参数在不同的剪裁比例下的敏感度
主要由两个API完成:
model = pdx.load_model('output/mobilenet_v2/best_model')
model.analyze_sensitivity(
dataset=eval_dataset, save_dir='output/mobilenet_v2/prune')
参数分析完后,output/mobilenet_v2/prune目录下会得到model.sensi.data文件,此文件保存了不同剪裁比例下各层参数的敏感度信息。
注意: 如果之前运行过该步骤,第二次运行时会自动加载已有的output/mobilenet_v2/prune/model.sensi.data,不再进行敏感度分析。
step 2: 根据选择的FLOPs减小比例对模型进行剪裁
model.prune(pruned_flops=.2, save_dir=None)
注意: 如果想直接保存剪裁完的模型参数,设置save_dir即可。但我们强烈建议对剪裁过的模型重新进行训练,以保证模型精度损失能尽可能少。
step 3: 对剪裁后的模型重新训练
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=32,
eval_dataset=eval_dataset,
lr_decay_epochs=[4, 6, 8],
learning_rate=0.025,
pretrain_weights=None,
save_dir='output/mobilenet_v2/prune',
use_vdl=True)
重新训练后的模型保存在output/mobilenet_v2/prune。
注意: 重新训练时需将pretrain_weights设置为None,否则模型会加载pretrain_weights指定的预训练模型参数。