Przeglądaj źródła

add params_not_prune for deeplabv3p_mobilenetv3_large

FlyingQianMM 5 lat temu
rodzic
commit
3403eb2f89
1 zmienionych plików z 24 dodań i 0 usunięć
  1. 24 0
      paddlex/cv/models/slim/prune_config.py

+ 24 - 0
paddlex/cv/models/slim/prune_config.py

@@ -285,11 +285,35 @@ def get_prune_params(model):
                 prune_names.remove(i)
 
     elif model_type.startswith('DeepLabv3p'):
+        if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+            params_not_prune = [
+                'last_1x1_conv_weights', 'conv14_se_2_weights',
+                'conv16_depthwise_weights', 'conv13_depthwise_weights',
+                'conv15_se_2_weights', 'conv2_depthwise_weights',
+                'conv6_depthwise_weights', 'conv8_depthwise_weights',
+                'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
+                'conv16_expand_weights', 'conv16_se_2_weights',
+                'conv10_depthwise_weights', 'conv11_depthwise_weights',
+                'conv15_expand_weights', 'conv5_expand_weights',
+                'conv15_depthwise_weights', 'conv14_depthwise_weights',
+                'conv12_se_2_weights', 'conv1_weights',
+                'conv13_expand_weights', 'conv_last_weights',
+                'conv12_depthwise_weights', 'conv13_se_2_weights',
+                'conv12_expand_weights', 'conv5_depthwise_weights',
+                'conv6_se_2_weights', 'conv10_expand_weights',
+                'conv9_depthwise_weights', 'conv6_expand_weights',
+                'conv5_se_2_weights', 'conv14_expand_weights',
+                'conv4_depthwise_weights', 'conv7_expand_weights',
+                'conv7_depthwise_weights'
+            ]
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
                 continue
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                 continue
+            if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+                if param.name in params_not_prune:
+                    continue
             prune_names.append(param.name)
         params_not_prune = [
             'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.