
Merge pull request #372 from PaddlePaddle/prune_tutorials

add prune tutorials
Jason 5 years ago
parent
commit
089b06af02

+ 54 - 12
paddlex/cv/models/slim/prune.py

@@ -21,9 +21,6 @@ import os.path as osp
 from functools import reduce
 import paddle.fluid as fluid
 from multiprocessing import Process, Queue
-import paddleslim
-from paddleslim.prune import Pruner, load_sensitivities
-from paddleslim.core import GraphWrapper
 from .prune_config import get_prune_params
 import paddlex.utils.logging as logging
 from paddlex.utils import seconds_to_hms
@@ -36,6 +33,10 @@ def sensitivity(program,
                 sensitivities_file=None,
                 pruned_ratios=None,
                 scope=None):
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     if scope is None:
         scope = fluid.global_scope()
     else:
@@ -104,7 +105,12 @@ def sensitivity(program,
     return sensitivities
 
 
-def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
+def channel_prune(program,
+                  prune_names,
+                  prune_ratios,
+                  place,
+                  only_graph=False,
+                  scope=None):
     """通道裁剪。
 
     Args:
@@ -119,6 +125,10 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, s
     Returns:
         paddle.fluid.Program: The pruned Program.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     prog_var_shape_dict = {}
     for var in program.list_vars():
         try:
@@ -163,6 +173,10 @@ def prune_program(model, prune_params_ratios=None):
         prune_params_ratios (dict): A dict mapping pruned parameter names to
             prune ratios. When None, the default parameter names and prune
             ratios are used. Defaults to None.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     assert model.status == 'Normal', 'Only the models saved while training are supported!'
     place = model.places[0]
     train_prog = model.train_prog
@@ -175,10 +189,15 @@ def prune_program(model, prune_params_ratios=None):
     prune_ratios = [
         prune_params_ratios[prune_name] for prune_name in prune_names
     ]
-    model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
-                                     place, scope=model.scope)
+    model.train_prog = channel_prune(
+        train_prog, prune_names, prune_ratios, place, scope=model.scope)
     model.test_prog = channel_prune(
-        eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
+        eval_prog,
+        prune_names,
+        prune_ratios,
+        place,
+        only_graph=True,
+        scope=model.scope)
 
 
 def update_program(program, model_dir, place, scope=None):
@@ -193,6 +212,10 @@ def update_program(program, model_dir, place, scope=None):
     Returns:
         paddle.fluid.Program: The updated Program.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     graph = GraphWrapper(program)
     with open(osp.join(model_dir, "prune.yml")) as f:
         shapes = yaml.load(f.read(), Loader=yaml.Loader)
@@ -203,11 +226,9 @@ def update_program(program, model_dir, place, scope=None):
     for block in program.blocks:
         for param in block.all_parameters():
             if param.name in shapes:
-                param_tensor = scope.find_var(
-                    param.name).get_tensor()
+                param_tensor = scope.find_var(param.name).get_tensor()
                 param_tensor.set(
-                    np.zeros(list(shapes[param.name])).astype('float32'),
-                    place)
+                    np.zeros(list(shapes[param.name])).astype('float32'), place)
     graph.update_groups_of_conv()
     graph.infer_shape()
     return program
@@ -243,6 +264,10 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
 
             where ``weight_0`` is a convolution kernel name, and ``sensitivities['weight_0']`` is a dict whose keys are prune ratios and whose values are sensitivities.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     assert model.status == 'Normal', 'Only the models saved while training are supported!'
     if os.path.exists(save_file):
         os.remove(save_file)
@@ -268,6 +293,11 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
     return sensitivitives
 
 
+def analysis(model, dataset, batch_size=8, save_file='./model.sensi.data'):
+    return cal_params_sensitivities(
+        model, eval_dataset=dataset, batch_size=batch_size, save_file=save_file)
+
+
 def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     """根据设定的精度损失容忍度metric_loss_thresh和计算保存的模型参数敏感度信息文件sensetive_file,
         获取裁剪的参数配置。
@@ -288,6 +318,10 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
 
             where the keys are convolution kernel names and the values are prune ratios.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     if not osp.exists(sensitivities_file):
         raise Exception('The sensitivities file does not exist!')
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
@@ -296,7 +330,11 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     return params_ratios
 
 
-def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
+def cal_model_size(program,
+                   place,
+                   sensitivities_file,
+                   eval_metric_loss=0.05,
+                   scope=None):
     """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
 
     Args:
@@ -309,6 +347,10 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, sc
     Returns:
         float: Ratio of the pruned model size to the current model size.
     """
+    import paddleslim
+    from paddleslim.prune import Pruner, load_sensitivities
+    from paddleslim.core import GraphWrapper
+
     prune_params_ratios = get_params_ratios(sensitivities_file,
                                             eval_metric_loss)
     prog_var_shape_dict = {}
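
The recurring change in this file moves the `paddleslim` imports from module scope into each function body, so that `import paddlex` no longer hard-requires PaddleSlim: it only has to be installed when a pruning API is actually called. A minimal sketch of the same lazy-import pattern (the `ImportError` message below is an illustration, not part of this PR):

```
def _load_slim_modules():
    """Import paddleslim only when a pruning API is actually invoked."""
    try:
        import paddleslim
        from paddleslim.prune import Pruner, load_sensitivities
        from paddleslim.core import GraphWrapper
    except ImportError:
        # paddleslim stays an optional dependency of paddlex.
        raise ImportError('Model pruning requires paddleslim; '
                          'install it with "pip install paddleslim".')
    return paddleslim, Pruner, load_sensitivities, GraphWrapper
```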

+ 5 - 2
paddlex/cv/models/slim/visualize.py

@@ -16,7 +16,6 @@ import os.path as osp
 import tqdm
 import numpy as np
 from .prune import cal_model_size
-from paddleslim.prune import load_sensitivities
 
 
 def visualize(model, sensitivities_file, save_dir='./'):
@@ -42,7 +41,11 @@ def visualize(model, sensitivities_file, save_dir='./'):
     y = list()
     for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))):
         prune_ratio = 1 - cal_model_size(
-            program, place, sensitivities_file, eval_metric_loss=loss_thresh, scope=model.scope)
+            program,
+            place,
+            sensitivities_file,
+            eval_metric_loss=loss_thresh,
+            scope=model.scope)
         x.append(prune_ratio)
         y.append(loss_thresh)
     plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3)

+ 1 - 1
tutorials/compress/classification/cal_sensitivities_file.py

@@ -39,7 +39,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument(
         "--model_dir",
-        default="./output/mobilenet/best_model",
+        default="./output/mobilenetv2/best_model",
         type=str,
         help="The model path.")
     parser.add_argument(

+ 51 - 0
tutorials/slim/prune/image_classification/README.md

@@ -0,0 +1,51 @@
+# Pruning an Image Classification Model
+
+## Step 1: Train the image classification model normally
+
+```
+python mobilenetv2_train.py
+```
+
+In this step, the trained model is saved under the `output/mobilenetv2` directory.
+
+## Step 2: Analyze model parameter sensitivities
+
+```
+python params_analysis.py
+```
+After the analysis finishes, a `mobilenetv2.sensi.data` file is produced, which stores the sensitivity of each parameter.  
+
+> We can then load the model together with the sensitivity file and visualize the result with the following command:
+> ```
+> python slim_visualize.py
+> ```
+> The visualization is shown in the figure below.
+The vertical axis is `eval_metric_loss` (the parameter you will configure in Step 3); the horizontal axis is the proportion of the model pruned away. From the figure we can see:  
+- with `eval_metric_loss` set to 0.05, 68.4% of the model is pruned away (31.6% remains)  
+- with `eval_metric_loss` set to 0.1, 78.5% of the model is pruned away (21.5% remains)
+
+![](./sensitivities.png)
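+
+For a given tolerance, PaddleX picks one prune ratio per convolution kernel from the sensitivity data; the curve above sums the parameters that survive. A hedged sketch (assuming the internal module path `paddlex/cv/models/slim/prune.py` added in this PR stays importable, and that `mobilenetv2.sensi.data` from Step 2 exists) to inspect those per-kernel ratios directly:
+
+```
+from paddlex.cv.models.slim.prune import get_params_ratios
+
+# Map each convolution kernel name to the prune ratio chosen
+# for this accuracy-loss tolerance.
+ratios = get_params_ratios('mobilenetv2.sensi.data', eval_metric_loss=0.05)
+for name, ratio in list(ratios.items())[:5]:
+    print(name, ratio)
+```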
+
+## Step 3: Prune and retrain the model
+
+```
+python mobilenetv2_prune_train.py
+```
+The code for this step is almost identical to Step 1. The only difference is that in the final call to the train function, `mobilenetv2_prune_train.py` changes four parameters: `pretrain_weights`, `save_dir`, `sensitivities_file`, and `eval_metric_loss`:
+
+- pretrain_weights: for pruned training, set to the model trained in Step 1
+- save_dir: where the model is saved during training
+- sensitivities_file: the parameter-sensitivity file obtained in Step 2
+- eval_metric_loss: the tolerance visualized in Step 2; adjusting it changes how much of the final model is pruned
+
+
+## Pruning results
+
+On this example dataset, the pruning results compare as follows. Prediction was run on **CPU with MKLDNN disabled**, and the reported time excludes data preprocessing and result postprocessing.  
+After pruning 68% of the model, accuracy actually improves and the per-image prediction time drops by 37%.
+
+
+| Model | Parameter file size | Prediction time | Accuracy |
+| :--- | :----------  | :------- | :--- |
+| MobileNetV2 |    8.7M       |   0.057s  | 0.92 |
+| MobileNetV2 (68% pruned) | 2.8M | 0.036s | 0.99 |
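+
+A rough sketch of how the timing above could be reproduced (assumptions: the pruned model was saved to `output/mobilenetv2_prune/best_model`, a sample image `test.jpg` exists, and hiding GPUs via `CUDA_VISIBLE_DEVICES` is enough to force CPU prediction):
+
+```
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = ''  # hide GPUs so prediction runs on CPU
+import time
+import paddlex as pdx
+
+model = pdx.load_model('output/mobilenetv2_prune/best_model')
+model.predict('test.jpg')  # warm-up run, excluded from the timing
+start = time.time()
+for _ in range(10):
+    model.predict('test.jpg')
+print('avg predict time: {:.3f}s'.format((time.time() - start) / 10))
+```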

+ 44 - 0
tutorials/slim/prune/image_classification/mobilenetv2_prune_train.py

@@ -0,0 +1,44 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.cls import transforms
+import paddlex as pdx
+
+veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.RandomCrop(crop_size=224), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+eval_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224), transforms.Normalize()
+])
+
+train_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/train_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=eval_transforms)
+
+model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
+
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=32,
+    eval_dataset=eval_dataset,
+    lr_decay_epochs=[4, 6, 8],
+    learning_rate=0.025,
+    pretrain_weights='output/mobilenetv2/best_model',
+    save_dir='output/mobilenetv2_prune',
+    sensitivities_file='./mobilenetv2.sensi.data',
+    eval_metric_loss=0.05,
+    use_vdl=True)

+ 41 - 0
tutorials/slim/prune/image_classification/mobilenetv2_train.py

@@ -0,0 +1,41 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.cls import transforms
+import paddlex as pdx
+
+veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
+pdx.utils.download_and_decompress(veg_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.RandomCrop(crop_size=224), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+eval_transforms = transforms.Compose([
+    transforms.ResizeByShort(short_size=256),
+    transforms.CenterCrop(crop_size=224), transforms.Normalize()
+])
+
+train_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/train_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=eval_transforms)
+
+model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
+
+model.train(
+    num_epochs=10,
+    train_dataset=train_dataset,
+    train_batch_size=32,
+    eval_dataset=eval_dataset,
+    lr_decay_epochs=[4, 6, 8],
+    learning_rate=0.025,
+    save_dir='output/mobilenetv2',
+    use_vdl=True)

+ 17 - 0
tutorials/slim/prune/image_classification/params_analysis.py

@@ -0,0 +1,17 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import paddlex as pdx
+
+model = pdx.load_model('output/mobilenetv2/best_model')
+
+eval_dataset = pdx.datasets.ImageNet(
+    data_dir='vegetables_cls',
+    file_list='vegetables_cls/val_list.txt',
+    label_list='vegetables_cls/labels.txt',
+    transforms=model.eval_transforms)
+
+pdx.slim.prune.analysis(
+    model,
+    dataset=eval_dataset,
+    batch_size=16,
+    save_file='mobilenetv2.sensi.data')

Binary
tutorials/slim/prune/image_classification/sensitivities.png


+ 3 - 0
tutorials/slim/prune/image_classification/slim_visualize.py

@@ -0,0 +1,3 @@
+import paddlex as pdx
+model = pdx.load_model('output/mobilenetv2/best_model')
+pdx.slim.visualize(model, 'mobilenetv2.sensi.data', save_dir='./')

+ 50 - 0
tutorials/slim/prune/object_detection/README.md

@@ -0,0 +1,50 @@
+# Pruning an Object Detection Model
+
+## Step 1: Train the object detection model normally
+
+```
+python yolov3_train.py
+```
+
+In this step, the trained model is saved under the `output/yolov3_mobilenetv1` directory.
+
+## Step 2: Analyze model parameter sensitivities
+
+```
+python params_analysis.py
+```
+After the analysis finishes, a `yolov3.sensi.data` file is produced, which stores the sensitivity of each parameter.  
+
+> We can then load the model together with the sensitivity file and visualize the result with the following command:
+> ```
+> python slim_visualize.py
+> ```
+> The visualization is shown in the figure below.
+The vertical axis is `eval_metric_loss` (the parameter you will configure in Step 3); the horizontal axis is the proportion of the model pruned away. From the figure we can see:  
+- with `eval_metric_loss` set to 0.05, 63.1% of the model is pruned away (36.9% remains)  
+- with `eval_metric_loss` set to 0.1, 68.6% of the model is pruned away (31.4% remains)
+
+![](./sensitivities.png)
+
+## Step 3: Prune and retrain the model
+
+```
+python yolov3_prune_train.py
+```
+The code for this step is almost identical to Step 1. The only difference is that in the final call to the train function, `yolov3_prune_train.py` changes four parameters: `pretrain_weights`, `save_dir`, `sensitivities_file`, and `eval_metric_loss`:
+
+- pretrain_weights: for pruned training, set to the model trained in Step 1
+- save_dir: where the model is saved during training
+- sensitivities_file: the parameter-sensitivity file obtained in Step 2
+- eval_metric_loss: the tolerance visualized in Step 2; adjusting it changes how much of the final model is pruned
+
+## Pruning results
+
+On this example dataset, the pruning results compare as follows. Prediction was run on **CPU with MKLDNN disabled**, and the reported time excludes data preprocessing and result postprocessing.  
+After pruning 63% of the model, accuracy actually improves and the per-image prediction time drops by 30%.
+
+
+| Model | Parameter file size | Prediction time | mAP |
+| :--- | :----------  | :------- | :--- |
+| YOLOv3-MobileNetV1 |    93M       |   1.045s  | 0.635 |
+| YOLOv3-MobileNetV1 (63% pruned) | 35M | 0.735s | 0.735 |
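+
+Before committing to a full prune-and-retrain run, the size of the pruned model can be estimated from the sensitivity file via `cal_model_size`. A hedged sketch (assuming the `model.test_prog`, `model.places`, and `model.scope` attributes used internally by `paddlex/cv/models/slim/prune.py` in this PR):
+
+```
+import paddlex as pdx
+from paddlex.cv.models.slim.prune import cal_model_size
+
+model = pdx.load_model('output/yolov3_mobilenetv1/best_model')
+# Ratio of the pruned model size to the current model size
+# at the given accuracy-loss tolerance.
+remaining = cal_model_size(
+    model.test_prog,
+    model.places[0],
+    'yolov3.sensi.data',
+    eval_metric_loss=0.05,
+    scope=model.scope)
+print('remaining: {:.1%}, pruned away: {:.1%}'.format(remaining,
+                                                      1 - remaining))
+```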

+ 14 - 0
tutorials/slim/prune/object_detection/params_analysis.py

@@ -0,0 +1,14 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import paddlex as pdx
+
+model = pdx.load_model('output/yolov3_mobilenetv1/best_model')
+
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=model.eval_transforms)
+
+pdx.slim.prune.analysis(
+    model, dataset=eval_dataset, batch_size=8, save_file='yolov3.sensi.data')

Binary
tutorials/slim/prune/object_detection/sensitivities.png


+ 3 - 0
tutorials/slim/prune/object_detection/slim_visualize.py

@@ -0,0 +1,3 @@
+import paddlex as pdx
+model = pdx.load_model('output/yolov3_mobilenetv1/best_model')
+pdx.slim.visualize(model, 'yolov3.sensi.data', save_dir='./')

+ 54 - 0
tutorials/slim/prune/object_detection/yolov3_prune_train.py

@@ -0,0 +1,54 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(insect_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.MixupImage(mixup_epoch=250),
+    transforms.RandomDistort(),
+    transforms.RandomExpand(),
+    transforms.RandomCrop(),
+    transforms.Resize(
+        target_size=608, interp='RANDOM'),
+    transforms.RandomHorizontalFlip(),
+    transforms.Normalize(),
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize(
+        target_size=608, interp='CUBIC'),
+    transforms.Normalize(),
+])
+
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms)
+
+num_classes = len(train_dataset.labels)
+
+model = pdx.det.YOLOv3(num_classes=num_classes, backbone='MobileNetV1')
+
+model.train(
+    num_epochs=270,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.000125,
+    lr_decay_epochs=[210, 240],
+    pretrain_weights='output/yolov3_mobilenetv1/best_model',
+    save_dir='output/yolov3_mobilenetv1_prune',
+    sensitivities_file='./yolov3.sensi.data',
+    eval_metric_loss=0.05,
+    use_vdl=True)

+ 51 - 0
tutorials/slim/prune/object_detection/yolov3_train.py

@@ -0,0 +1,51 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
+pdx.utils.download_and_decompress(insect_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.MixupImage(mixup_epoch=250),
+    transforms.RandomDistort(),
+    transforms.RandomExpand(),
+    transforms.RandomCrop(),
+    transforms.Resize(
+        target_size=608, interp='RANDOM'),
+    transforms.RandomHorizontalFlip(),
+    transforms.Normalize(),
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize(
+        target_size=608, interp='CUBIC'),
+    transforms.Normalize(),
+])
+
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/train_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=eval_transforms)
+
+num_classes = len(train_dataset.labels)
+
+model = pdx.det.YOLOv3(num_classes=num_classes, backbone='MobileNetV1')
+
+model.train(
+    num_epochs=270,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.000125,
+    lr_decay_epochs=[210, 240],
+    save_dir='output/yolov3_mobilenetv1',
+    use_vdl=True)

+ 51 - 0
tutorials/slim/prune/semantic_segmentation/README.md

@@ -0,0 +1,51 @@
+# Pruning a Semantic Segmentation Model
+
+## Step 1: Train the semantic segmentation model normally
+
+```
+python unet_train.py
+```
+
+In this step, the trained model is saved under the `output/unet` directory.
+
+## Step 2: Analyze model parameter sensitivities
+
+```
+python params_analysis.py
+```
+After the analysis finishes, a `unet.sensi.data` file is produced, which stores the sensitivity of each parameter.  
+
+> We can then load the model together with the sensitivity file and visualize the result with the following command:
+> ```
+> python slim_visualize.py
+> ```
+> The visualization is shown in the figure below.
+The vertical axis is `eval_metric_loss` (the parameter you will configure in Step 3); the horizontal axis is the proportion of the model pruned away. From the figure we can see:  
+- with `eval_metric_loss` set to 0.05, 64.1% of the model is pruned away (35.9% remains)  
+- with `eval_metric_loss` set to 0.1, 70.9% of the model is pruned away (29.1% remains)
+
+![](./sensitivities.png)
+
+## Step 3: Prune and retrain the model
+
+```
+python unet_prune_train.py
+```
+The code for this step is almost identical to Step 1. The only difference is that in the final call to the train function, `unet_prune_train.py` changes four parameters: `pretrain_weights`, `save_dir`, `sensitivities_file`, and `eval_metric_loss`:
+
+- pretrain_weights: for pruned training, set to the model trained in Step 1
+- save_dir: where the model is saved during training
+- sensitivities_file: the parameter-sensitivity file obtained in Step 2
+- eval_metric_loss: the tolerance visualized in Step 2; adjusting it changes how much of the final model is pruned
+
+## Pruning results
+
+On this example dataset, the pruning results compare as follows. Prediction was run on **CPU with MKLDNN disabled**, and the reported time excludes data preprocessing and result postprocessing.  
+After pruning 64% of the model, accuracy stays essentially unchanged and the per-image prediction time drops by nearly 50%.
+
+> UNet is used here only for comparison; on low-performance devices, a lightweight segmentation model such as deeplab-mobilenet or fastscnn is recommended instead.
+
+| Model | Parameter file size | Prediction time | mIoU |
+| :--- | :----------  | :------- | :--- |
+| UNet |    52M       |   9.85s  | 0.915 |
+| UNet (64% pruned) | 19M | 4.80s | 0.911 |
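+
+To double-check the numbers above on your own run, the pruned model can be re-evaluated on the validation set. A minimal sketch (assuming `model.evaluate` accepts an eval dataset and batch size, as in other PaddleX tutorials, and reports mIoU for segmentation models):
+
+```
+import paddlex as pdx
+
+model = pdx.load_model('output/unet_prune/best_model')
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=model.eval_transforms)
+# Returns a dict of evaluation metrics for the pruned model.
+print(model.evaluate(eval_dataset, batch_size=4))
+```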

+ 14 - 0
tutorials/slim/prune/semantic_segmentation/params_analysis.py

@@ -0,0 +1,14 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import paddlex as pdx
+
+model = pdx.load_model('output/unet/best_model')
+
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=model.eval_transforms)
+
+pdx.slim.prune.analysis(
+    model, dataset=eval_dataset, batch_size=4, save_file='unet.sensi.data')

Binary
tutorials/slim/prune/semantic_segmentation/sensitivities.png


+ 3 - 0
tutorials/slim/prune/semantic_segmentation/slim_visualize.py

@@ -0,0 +1,3 @@
+import paddlex as pdx
+model = pdx.load_model('output/unet/best_model')
+pdx.slim.visualize(model, 'unet.sensi.data', save_dir='./')

+ 46 - 0
tutorials/slim/prune/semantic_segmentation/unet_prune_train.py

@@ -0,0 +1,46 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
+pdx.utils.download_and_decompress(optic_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
+    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.ResizeByLong(long_size=512), transforms.Padding(target_size=512),
+    transforms.Normalize()
+])
+
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/train_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=eval_transforms)
+
+num_classes = len(train_dataset.labels)
+
+model = pdx.seg.UNet(num_classes=num_classes)
+
+model.train(
+    num_epochs=20,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    pretrain_weights='output/unet/best_model',
+    save_dir='output/unet_prune',
+    sensitivities_file='./unet.sensi.data',
+    eval_metric_loss=0.05,
+    use_vdl=True)

+ 43 - 0
tutorials/slim/prune/semantic_segmentation/unet_train.py

@@ -0,0 +1,43 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
+pdx.utils.download_and_decompress(optic_dataset, path='./')
+
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
+    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.ResizeByLong(long_size=512), transforms.Padding(target_size=512),
+    transforms.Normalize()
+])
+
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/train_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=eval_transforms)
+
+num_classes = len(train_dataset.labels)
+
+model = pdx.seg.UNet(num_classes=num_classes)
+
+model.train(
+    num_epochs=20,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    save_dir='output/unet',
+    use_vdl=True)