فهرست منبع

Merge pull request #137 from SunAhong1993/syf_prune2

add alexnet prune
Jason 5 سال پیش
والد
کامیت
cf529fc1f3
1فایلهای تغییر یافته به همراه4 افزوده شده و 1 حذف شده
  1. 4 1
      paddlex/cv/models/slim/prune_config.py

+ 4 - 1
paddlex/cv/models/slim/prune_config.py

@@ -142,13 +142,16 @@ def get_prune_params(model):
     program = model.test_prog
     if model_type.startswith('ResNet') or \
             model_type.startswith('DenseNet') or \
-            model_type.startswith('DarkNet'):
+            model_type.startswith('DarkNet') or \
+            model_type.startswith('AlexNet'):
         for block in program.blocks:
             for param in block.all_parameters():
                 pd_var = fluid.global_scope().find_var(param.name)
                 pd_param = pd_var.get_tensor()
                 if len(np.array(pd_param).shape) == 4:
                     prune_names.append(param.name)
+        if model_type == 'AlexNet':
+            prune_names.remove('conv5_weights')
     elif model_type == "MobileNetV1":
         prune_names.append("conv1_weights")
         for param in program.global_block().all_parameters():