|
|
@@ -91,7 +91,23 @@ sensitivities_data = {
|
|
|
'DeepLabv3p_Xception65_aspp_decoder':
|
|
|
'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
|
|
|
'DeepLabv3p_Xception41_aspp_decoder':
|
|
|
- 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities',
|
|
|
+ 'HRNet_W18_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
|
|
|
+ 'HRNet_W30_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities',
|
|
|
+ 'HRNet_W32_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities',
|
|
|
+ 'HRNet_W40_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities',
|
|
|
+ 'HRNet_W44_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities',
|
|
|
+ 'HRNet_W48_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities',
|
|
|
+ 'HRNet_W64_Seg':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities',
|
|
|
+ 'FastSCNN':
|
|
|
+ 'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities'
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -105,6 +121,8 @@ def get_sensitivities(flag, model, save_dir):
|
|
|
elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
|
|
|
'enable_decoder'):
|
|
|
model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
|
|
|
+ if model_type.startswith('HRNet') and model.model_type == 'segmenter':
|
|
|
+ model_type = '{}_W{}_Seg'.format(model_type, model.width)
|
|
|
if osp.isfile(flag):
|
|
|
return flag
|
|
|
elif flag == 'DEFAULT':
|
|
|
@@ -243,19 +261,17 @@ def get_prune_params(model):
|
|
|
for i in params_not_prune:
|
|
|
if i in prune_names:
|
|
|
prune_names.remove(i)
|
|
|
-
|
|
|
- elif model_type.startswith('HRNet'):
|
|
|
+
|
|
|
+ elif model_type.startswith('HRNet') and model.model_type == 'segmenter':
|
|
|
for param in program.global_block().all_parameters():
|
|
|
if 'weight' not in param.name:
|
|
|
continue
|
|
|
prune_names.append(param.name)
|
|
|
- params_not_prune = [
|
|
|
- 'conv-1_weights'
|
|
|
- ]
|
|
|
+ params_not_prune = ['conv-1_weights']
|
|
|
for i in params_not_prune:
|
|
|
if i in prune_names:
|
|
|
prune_names.remove(i)
|
|
|
-
|
|
|
+
|
|
|
elif model_type.startswith('FastSCNN'):
|
|
|
for param in program.global_block().all_parameters():
|
|
|
if 'weight' not in param.name:
|
|
|
@@ -263,9 +279,7 @@ def get_prune_params(model):
|
|
|
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
|
|
|
continue
|
|
|
prune_names.append(param.name)
|
|
|
- params_not_prune = [
|
|
|
- 'classifier/weights'
|
|
|
- ]
|
|
|
+ params_not_prune = ['classifier/weights']
|
|
|
for i in params_not_prune:
|
|
|
if i in prune_names:
|
|
|
prune_names.remove(i)
|