浏览代码

add hrnet and fast_scnn

chenguowei01 5 年之前
父节点
当前提交
68e0476fae
共有 1 个文件被更改,包括 26 次插入0 次删除
  1. 26 0
      paddlex/cv/models/slim/prune_config.py

+ 26 - 0
paddlex/cv/models/slim/prune_config.py

@@ -243,6 +243,32 @@ def get_prune_params(model):
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
+    
+    elif model_type.startswith('HRNet'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'conv-1_weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
+    
+    elif model_type.startswith('FastSCNN'):
+        for param in program.global_block().all_parameters():
+            if 'weight' not in param.name:
+                continue
+            if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
+                continue
+            prune_names.append(param.name)
+        params_not_prune = [
+            'classifier/weights'
+        ]
+        for i in params_not_prune:
+            if i in prune_names:
+                prune_names.remove(i)
 
     elif model_type.startswith('DeepLabv3p'):
         for param in program.global_block().all_parameters():