|
|
@@ -243,6 +243,32 @@ def get_prune_params(model):
|
|
|
for i in params_not_prune:
|
|
|
if i in prune_names:
|
|
|
prune_names.remove(i)
|
|
|
+
|
|
|
+ elif model_type.startswith('HRNet'):
|
|
|
+ for param in program.global_block().all_parameters():
|
|
|
+ if 'weight' not in param.name:
|
|
|
+ continue
|
|
|
+ prune_names.append(param.name)
|
|
|
+ params_not_prune = [
|
|
|
+ 'conv-1_weights'
|
|
|
+ ]
|
|
|
+ for i in params_not_prune:
|
|
|
+ if i in prune_names:
|
|
|
+ prune_names.remove(i)
|
|
|
+
|
|
|
+ elif model_type.startswith('FastSCNN'):
|
|
|
+ for param in program.global_block().all_parameters():
|
|
|
+ if 'weight' not in param.name:
|
|
|
+ continue
|
|
|
+ if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
|
|
|
+ continue
|
|
|
+ prune_names.append(param.name)
|
|
|
+ params_not_prune = [
|
|
|
+ 'classifier/weights'
|
|
|
+ ]
|
|
|
+ for i in params_not_prune:
|
|
|
+ if i in prune_names:
|
|
|
+ prune_names.remove(i)
|
|
|
|
|
|
elif model_type.startswith('DeepLabv3p'):
|
|
|
for param in program.global_block().all_parameters():
|