jiangjiajun 5 жил өмнө
parent
commit
305e50943a

+ 51 - 0
tutorials/slim/prune/image_classification/README.md

@@ -0,0 +1,51 @@
+# 图像分类模型裁剪训练
+
+## 第一步 正常训练图像分类模型
+
+```
+python mobilenetv2_train.py
+```
+
+在此步骤中,训练的模型会保存在`output/mobilenetv2`目录下
+
+## 第二步 分析模型参数信息
+
+```
+python params_analysis.py
+```
+参数分析完后,会得到`mobilenetv2.sensi.data`文件,此文件保存了各参数的敏感度信息。  
+
+> 我们可以继续加载模型和敏感度文件,进行可视化,如下命令所示
+> ```
+> python slim_visualize.py
+> ```
+> 可视化结果如下图
+纵轴为`eval_metric_loss`(接下来第三步需要配置的参数),横轴为模型被裁剪的比例,从图中可以看到,  
+- 当`eval_metric_loss`设0.05时,模型被裁掉68.4%(剩余31.6%)  
+- 当`eval_metric_loss`设0.1时,模型被裁掉78.5%(剩余21.5%)
+
+![](./sensitivities.png)
+
+## 第三步 模型进行裁剪训练
+
+```
+python mobilenetv2_prune_train.py
+```
+此步骤的代码与第一步的代码基本一致,唯一的区别是在最后的train函数中,`mobilenetv2_prune_train.py`修改了里面的`pretrain_weights`、`save_dir`、`sensitivities_file`和`eval_metric_loss`四个参数
+
+- pretrain_weights: 在裁剪训练中,设置为之前训练好的模型
+- save_dir: 模型训练过程中,模型的保存位置
+- sensitivities_file: 在第二步中分析得到的参数敏感度信息文件
+- eval_metric_loss: 第二步中可视化的相关参数,通过此参数可相应的改变最终模型被裁剪的比例
+
+
+## 裁剪效果
+
+在本示例数据上,裁剪效果对比如下,其中预测采用**CPU,关闭MKLDNN**进行预测,预测时间不包含数据的预处理和结果的后处理。  
+可以看到在模型被裁剪掉68%后,模型精度还有上升,单张图片的预测用时减少了37%。
+
+
+| 模型 | 参数文件大小 | 预测速度 | 准确率 |
+| :--- | :----------  | :------- | :--- |
+| MobileNetV2 |    8.7M       |   0.057s  | 0.92 |
+| MobileNetV2(裁掉68%) | 2.8M | 0.036s | 0.99 |

+ 44 - 0
tutorials/slim/prune/image_classification/mobilenetv2_prune_train.py

@@ -0,0 +1,44 @@
+# Step 3 of the pruning tutorial: retrain MobileNetV2 with channel pruning,
+# starting from the weights produced by mobilenetv2_train.py and the
+# sensitivity file produced by params_analysis.py.
+import os
+# Restrict training to GPU 0 (must be set before paddle is imported).
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.cls import transforms
+import paddlex as pdx
+
+# Download and unpack the vegetables classification dataset into ./
+veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# Training-time preprocessing: random 224x224 crop + horizontal flip, then
+# normalization. Must match mobilenetv2_train.py so results are comparable.
+train_transforms = transforms.Compose([
+    transforms.RandomCrop(crop_size=224), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+# Evaluation preprocessing: deterministic resize + center crop + normalize.
+eval_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224), transforms.Normalize()
+])
+
+# ImageNet-style dataset: file lists map image paths to label indices.
+train_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/train_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=eval_transforms)
+
+# Number of classes is derived from labels.txt via the dataset.
+model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
+
+# Identical to the normal training run except for the last four arguments:
+#   pretrain_weights    - resume from the best model of the normal run
+#   save_dir            - separate output dir for the pruned model
+#   sensitivities_file  - per-parameter sensitivity data from step 2
+#   eval_metric_loss    - tolerated metric drop; 0.05 prunes ~68.4% of the
+#                         model according to the sensitivity analysis (README)
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=32,
+    eval_dataset=eval_dataset,
+    lr_decay_epochs=[4, 6, 8],
+    learning_rate=0.025,
+    pretrain_weights='output/mobilenetv2/best_model',
+    save_dir='output/mobilenetv2_prune',
+    sensitivities_file='./mobilenetv2.sensi.data',
+    eval_metric_loss=0.05,
+    use_vdl=True)

+ 41 - 0
tutorials/slim/prune/image_classification/mobilenetv2_train.py

@@ -0,0 +1,41 @@
+# Step 1 of the pruning tutorial: train a baseline MobileNetV2 classifier.
+# Its best checkpoint (output/mobilenetv2/best_model) is consumed by the
+# later sensitivity-analysis and prune-training scripts.
+import os
+# Restrict training to GPU 0 (must be set before paddle is imported).
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.cls import transforms
+import paddlex as pdx
+
+# Download and unpack the vegetables classification dataset into ./
+veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+# Training-time preprocessing: random 224x224 crop + horizontal flip, then
+# normalization.
+train_transforms = transforms.Compose([
+    transforms.RandomCrop(crop_size=224), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+# Evaluation preprocessing: deterministic resize + center crop + normalize.
+eval_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224), transforms.Normalize()
+])
+
+# ImageNet-style dataset: file lists map image paths to label indices.
+train_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/train_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=eval_transforms)
+
+# Number of classes is derived from labels.txt via the dataset.
+model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
+
+# Normal (un-pruned) training; checkpoints go to output/mobilenetv2.
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=32,
+    eval_dataset=eval_dataset,
+    lr_decay_epochs=[4, 6, 8],
+    learning_rate=0.025,
+    save_dir='output/mobilenetv2',
+    use_vdl=True)

+ 17 - 0
tutorials/slim/prune/image_classification/params_analysis.py

@@ -0,0 +1,17 @@
+# Step 2 of the pruning tutorial: measure per-parameter pruning sensitivity
+# of the trained model on the validation set. The result is written to
+# mobilenetv2.sensi.data, consumed by mobilenetv2_prune_train.py and
+# slim_visualize.py.
+import os
+# Restrict evaluation to GPU 0 (must be set before paddle is imported).
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import paddlex as pdx
+
+# Load the best checkpoint produced by mobilenetv2_train.py.
+model = pdx.load_model('output/mobilenetv2/best_model')
+
+# Validation split, preprocessed with the transforms saved in the model
+# so evaluation matches training-time settings.
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=model.eval_transforms)
+
+# Run the sensitivity analysis and persist it for the pruning step.
+pdx.slim.prune.analysis(
+    model,
+    dataset=eval_dataset,
+    batch_size=16,
+    save_file='mobilenetv2.sensi.data')

BIN
tutorials/slim/prune/image_classification/sensitivities.png


+ 3 - 0
tutorials/slim/prune/image_classification/slim_visualize.py

@@ -0,0 +1,3 @@
+# Optional visualization step: plot the pruning-ratio curve from the
+# sensitivity file produced by params_analysis.py (saves sensitivities.png
+# into the current directory).
+import paddlex as pdx
+model = pdx.load_model('output/mobilenetv2/best_model')
+pdx.slim.visualize(model, 'mobilenetv2.sensi.data', save_dir='./')