Browse Source

add hrnet and fast_scnn

chenguowei01 5 years ago
parent
commit
68e0476fae
1 changed files with 26 additions and 0 deletions
  1. 26 0
      paddlex/cv/models/slim/prune_config.py

+ 26 - 0
paddlex/cv/models/slim/prune_config.py

@@ -243,6 +243,32 @@ def get_prune_params(model):
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
+    
+    elif model_type.startswith('HRNet'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'conv-1_weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
+    
+    elif model_type.startswith('FastSCNN'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'classifier/weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
 
     elif model_type.startswith('DeepLabv3p'):
         for param in program.global_block().all_parameters():