|
@@ -285,11 +285,35 @@ def get_prune_params(model):
|
|
|
prune_names.remove(i)
|
|
prune_names.remove(i)
|
|
|
|
|
|
|
|
elif model_type.startswith('DeepLabv3p'):
|
|
elif model_type.startswith('DeepLabv3p'):
|
|
|
|
|
+ if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
|
|
|
|
|
+ params_not_prune = [
|
|
|
|
|
+ 'last_1x1_conv_weights', 'conv14_se_2_weights',
|
|
|
|
|
+ 'conv16_depthwise_weights', 'conv13_depthwise_weights',
|
|
|
|
|
+ 'conv15_se_2_weights', 'conv2_depthwise_weights',
|
|
|
|
|
+ 'conv6_depthwise_weights', 'conv8_depthwise_weights',
|
|
|
|
|
+ 'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
|
|
|
|
|
+ 'conv16_expand_weights', 'conv16_se_2_weights',
|
|
|
|
|
+ 'conv10_depthwise_weights', 'conv11_depthwise_weights',
|
|
|
|
|
+ 'conv15_expand_weights', 'conv5_expand_weights',
|
|
|
|
|
+ 'conv15_depthwise_weights', 'conv14_depthwise_weights',
|
|
|
|
|
+ 'conv12_se_2_weights', 'conv1_weights',
|
|
|
|
|
+ 'conv13_expand_weights', 'conv_last_weights',
|
|
|
|
|
+ 'conv12_depthwise_weights', 'conv13_se_2_weights',
|
|
|
|
|
+ 'conv12_expand_weights', 'conv5_depthwise_weights',
|
|
|
|
|
+ 'conv6_se_2_weights', 'conv10_expand_weights',
|
|
|
|
|
+ 'conv9_depthwise_weights', 'conv6_expand_weights',
|
|
|
|
|
+ 'conv5_se_2_weights', 'conv14_expand_weights',
|
|
|
|
|
+ 'conv4_depthwise_weights', 'conv7_expand_weights',
|
|
|
|
|
+ 'conv7_depthwise_weights'
|
|
|
|
|
+ ]
|
|
|
for param in program.global_block().all_parameters():
|
|
for param in program.global_block().all_parameters():
|
|
|
if 'weight' not in param.name:
|
|
if 'weight' not in param.name:
|
|
|
continue
|
|
continue
|
|
|
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
|
|
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
|
|
|
continue
|
|
continue
|
|
|
|
|
+ if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
|
|
|
|
|
+ if param.name in params_not_prune:
|
|
|
|
|
+ continue
|
|
|
prune_names.append(param.name)
|
|
prune_names.append(param.name)
|
|
|
params_not_prune = [
|
|
params_not_prune = [
|
|
|
'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
|
|
'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
|