|
|
@@ -427,12 +427,7 @@ class BaseModel:
|
|
|
pre_pruning_flops = flops(self.net, self.pruner.inputs)
|
|
|
logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
|
|
|
pre_pruning_flops))
|
|
|
- skip_vars = []
|
|
|
- for param in self.net.parameters():
|
|
|
- if param.shape[0] <= 8:
|
|
|
- skip_vars.append(param.name)
|
|
|
- _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops,
|
|
|
- skip_vars)
|
|
|
+ _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
|
|
|
post_pruning_flops = flops(self.net, self.pruner.inputs)
|
|
|
logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
|
|
|
post_pruning_flops))
|