|
|
@@ -142,13 +142,16 @@ def get_prune_params(model):
|
|
|
program = model.test_prog
|
|
|
if model_type.startswith('ResNet') or \
|
|
|
model_type.startswith('DenseNet') or \
|
|
|
- model_type.startswith('DarkNet'):
|
|
|
+ model_type.startswith('DarkNet') or \
|
|
|
+ model_type.startswith('AlexNet'):
|
|
|
for block in program.blocks:
|
|
|
for param in block.all_parameters():
|
|
|
pd_var = fluid.global_scope().find_var(param.name)
|
|
|
pd_param = pd_var.get_tensor()
|
|
|
if len(np.array(pd_param).shape) == 4:
|
|
|
prune_names.append(param.name)
|
|
|
+ if model_type == 'AlexNet':
|
|
|
+ prune_names.remove('conv5_weights')
|
|
|
elif model_type == "MobileNetV1":
|
|
|
prune_names.append("conv1_weights")
|
|
|
for param in program.global_block().all_parameters():
|