Răsfoiți Sursa

add prune configs and prompt

sunyanfang01 5 ani în urmă
părinte
comite
5a4783ae37

+ 1 - 0
paddlex/cv/models/load_model.py

@@ -108,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
 
     logging.info("Model[{}] loaded.".format(info['Model']))
     model.trainable = False
+    model.status = status
     return model
 
 

+ 3 - 0
paddlex/cv/models/slim/prune.py

@@ -158,6 +158,7 @@ def prune_program(model, prune_params_ratios=None):
         prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
             使用默认裁剪参数名和裁剪率。默认为None。
     """
+    assert model.status == 'Normal', 'Only the model after training can be pruned!'
     place = model.places[0]
     train_prog = model.train_prog
     eval_prog = model.test_prog
@@ -235,6 +236,8 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
 
             其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
     """
+    print('-----------', model.status)
+    assert model.status == 'Normal', 'Only the model after training can calculate sensitivities data!'
     if os.path.exists(save_file):
         os.remove(save_file)
 

+ 30 - 1
paddlex/cv/models/slim/prune_config.py

@@ -19,6 +19,8 @@ import paddle.fluid as fluid
 import paddlex
 
 sensitivities_data = {
+    'AlexNet':
+    'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
     'ResNet18':
     'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
     'ResNet34':
@@ -41,6 +43,10 @@ sensitivities_data = {
     'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
     'MobileNetV3_small':
     'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
+    'MobileNetV3_large_ssld':
+    'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
+    'MobileNetV3_small_ssld':
+    'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
     'DenseNet121':
     'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
     'DenseNet161':
@@ -143,7 +149,8 @@ def get_prune_params(model):
     if model_type.startswith('ResNet') or \
             model_type.startswith('DenseNet') or \
             model_type.startswith('DarkNet') or \
-            model_type.startswith('AlexNet'):
+            model_type.startswith('AlexNet') or \
+            model_type.startswith('ShuffleNetV2'):
         for block in program.blocks:
             for param in block.all_parameters():
                 pd_var = fluid.global_scope().find_var(param.name)
@@ -152,6 +159,28 @@ def get_prune_params(model):
                     prune_names.append(param.name)
         if model_type == 'AlexNet':
             prune_names.remove('conv5_weights')
+        if model_type == 'ShuffleNetV2':
+            not_prune_names = ['stage_2_1_conv5_weights',
+                        'stage_2_1_conv3_weights',
+                        'stage_2_2_conv3_weights',
+                        'stage_2_3_conv3_weights',
+                        'stage_2_4_conv3_weights',
+                        'stage_3_1_conv5_weights',
+                        'stage_3_1_conv3_weights',
+                        'stage_3_2_conv3_weights',
+                        'stage_3_3_conv3_weights',
+                        'stage_3_4_conv3_weights',
+                        'stage_3_5_conv3_weights',
+                        'stage_3_6_conv3_weights',
+                        'stage_3_7_conv3_weights',
+                        'stage_3_8_conv3_weights',
+                        'stage_4_1_conv5_weights',
+                        'stage_4_1_conv3_weights',
+                        'stage_4_2_conv3_weights',
+                        'stage_4_3_conv3_weights',
+                        'stage_4_4_conv3_weights',]
+            for name in not_prune_names:
+                 prune_names.remove(name)
     elif model_type == "MobileNetV1":
         prune_names.append("conv1_weights")
         for param in program.global_block().all_parameters():