瀏覽代碼

Merge pull request #293 from FlyingQianMM/develop_test

add sensitive files for hrnet and fastscnn
Jason 5 年之前
父節點
當前提交
d4b719678a
共有 1 個文件被更改,包括 24 次插入10 次删除
  1. 24 10
      paddlex/cv/models/slim/prune_config.py

+ 24 - 10
paddlex/cv/models/slim/prune_config.py

@@ -91,7 +91,23 @@ sensitivities_data = {
     'DeepLabv3p_Xception65_aspp_decoder':
     'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
     'DeepLabv3p_Xception41_aspp_decoder':
-    'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
+    'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities',
+    'HRNet_W18_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
+    'HRNet_W30_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities',
+    'HRNet_W32_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities',
+    'HRNet_W40_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities',
+    'HRNet_W44_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities',
+    'HRNet_W48_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities',
+    'HRNet_W64_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities',
+    'FastSCNN':
+    'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities'
 }
 
 
@@ -105,6 +121,8 @@ def get_sensitivities(flag, model, save_dir):
     elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
                                                         'enable_decoder'):
         model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
+    if model_type.startswith('HRNet') and model.model_type == 'segmenter':
+        model_type = '{}_W{}_Seg'.format(model_type, model.width)
     if osp.isfile(flag):
         return flag
     elif flag == 'DEFAULT':
@@ -243,19 +261,17 @@ def get_prune_params(model):
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
-    
-    elif model_type.startswith('HRNet'):
+
+    elif model_type.startswith('HRNet') and model.model_type == 'segmenter':
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
                 continue
             prune_names.append(param.name)
-        params_not_prune = [
-            'conv-1_weights'
-        ]
+        params_not_prune = ['conv-1_weights']
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
-    
+
     elif model_type.startswith('FastSCNN'):
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
@@ -263,9 +279,7 @@ def get_prune_params(model):
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                 continue
             prune_names.append(param.name)
-        params_not_prune = [
-            'classifier/weights'
-        ]
+        params_not_prune = ['classifier/weights']
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)