|
@@ -19,6 +19,8 @@ import paddle.fluid as fluid
|
|
|
import paddlex
|
|
import paddlex
|
|
|
|
|
|
|
|
sensitivities_data = {
|
|
sensitivities_data = {
|
|
|
|
|
+ 'AlexNet':
|
|
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
|
|
|
'ResNet18':
|
|
'ResNet18':
|
|
|
'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
|
|
'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
|
|
|
'ResNet34':
|
|
'ResNet34':
|
|
@@ -41,6 +43,10 @@ sensitivities_data = {
|
|
|
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
|
|
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
|
|
|
'MobileNetV3_small':
|
|
'MobileNetV3_small':
|
|
|
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
|
|
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
|
|
|
|
|
+ 'MobileNetV3_large_ssld':
|
|
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
|
|
|
|
|
+ 'MobileNetV3_small_ssld':
|
|
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
|
|
|
'DenseNet121':
|
|
'DenseNet121':
|
|
|
'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
|
|
'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
|
|
|
'DenseNet161':
|
|
'DenseNet161':
|
|
@@ -143,7 +149,8 @@ def get_prune_params(model):
|
|
|
if model_type.startswith('ResNet') or \
|
|
if model_type.startswith('ResNet') or \
|
|
|
model_type.startswith('DenseNet') or \
|
|
model_type.startswith('DenseNet') or \
|
|
|
model_type.startswith('DarkNet') or \
|
|
model_type.startswith('DarkNet') or \
|
|
|
- model_type.startswith('AlexNet'):
|
|
|
|
|
|
|
+ model_type.startswith('AlexNet') or \
|
|
|
|
|
+ model_type.startswith('ShuffleNetV2'):
|
|
|
for block in program.blocks:
|
|
for block in program.blocks:
|
|
|
for param in block.all_parameters():
|
|
for param in block.all_parameters():
|
|
|
pd_var = fluid.global_scope().find_var(param.name)
|
|
pd_var = fluid.global_scope().find_var(param.name)
|
|
@@ -152,6 +159,28 @@ def get_prune_params(model):
|
|
|
prune_names.append(param.name)
|
|
prune_names.append(param.name)
|
|
|
if model_type == 'AlexNet':
|
|
if model_type == 'AlexNet':
|
|
|
prune_names.remove('conv5_weights')
|
|
prune_names.remove('conv5_weights')
|
|
|
|
|
+ if model_type == 'ShuffleNetV2':
|
|
|
|
|
+ not_prune_names = ['stage_2_1_conv5_weights',
|
|
|
|
|
+ 'stage_2_1_conv3_weights',
|
|
|
|
|
+ 'stage_2_2_conv3_weights',
|
|
|
|
|
+ 'stage_2_3_conv3_weights',
|
|
|
|
|
+ 'stage_2_4_conv3_weights',
|
|
|
|
|
+ 'stage_3_1_conv5_weights',
|
|
|
|
|
+ 'stage_3_1_conv3_weights',
|
|
|
|
|
+ 'stage_3_2_conv3_weights',
|
|
|
|
|
+ 'stage_3_3_conv3_weights',
|
|
|
|
|
+ 'stage_3_4_conv3_weights',
|
|
|
|
|
+ 'stage_3_5_conv3_weights',
|
|
|
|
|
+ 'stage_3_6_conv3_weights',
|
|
|
|
|
+ 'stage_3_7_conv3_weights',
|
|
|
|
|
+ 'stage_3_8_conv3_weights',
|
|
|
|
|
+ 'stage_4_1_conv5_weights',
|
|
|
|
|
+ 'stage_4_1_conv3_weights',
|
|
|
|
|
+ 'stage_4_2_conv3_weights',
|
|
|
|
|
+ 'stage_4_3_conv3_weights',
|
|
|
|
|
+ 'stage_4_4_conv3_weights',]
|
|
|
|
|
+ for name in not_prune_names:
|
|
|
|
|
+ prune_names.remove(name)
|
|
|
elif model_type == "MobileNetV1":
|
|
elif model_type == "MobileNetV1":
|
|
|
prune_names.append("conv1_weights")
|
|
prune_names.append("conv1_weights")
|
|
|
for param in program.global_block().all_parameters():
|
|
for param in program.global_block().all_parameters():
|