|
|
@@ -122,7 +122,14 @@ class BaseModel:
|
|
|
info = dict()
|
|
|
info['pruner'] = self.pruner.__class__.__name__
|
|
|
info['pruning_ratios'] = self.pruning_ratios
|
|
|
- info['pruner_inputs'] = self.pruner.inputs
|
|
|
+ pruner_inputs = self.pruner.inputs
|
|
|
+ if self.model_type == 'detector':
|
|
|
+ pruner_inputs = {
|
|
|
+ k: v.tolist()
|
|
|
+ for k, v in pruner_inputs[0].items()
|
|
|
+ }
|
|
|
+ info['pruner_inputs'] = pruner_inputs
|
|
|
+
|
|
|
return info
|
|
|
|
|
|
def get_quant_info(self):
|
|
|
@@ -427,12 +434,7 @@ class BaseModel:
|
|
|
pre_pruning_flops = flops(self.net, self.pruner.inputs)
|
|
|
logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
|
|
|
pre_pruning_flops))
|
|
|
- skip_vars = []
|
|
|
- for param in self.net.parameters():
|
|
|
- if param.shape[0] <= 8:
|
|
|
- skip_vars.append(param.name)
|
|
|
- _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops,
|
|
|
- skip_vars)
|
|
|
+ _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
|
|
|
post_pruning_flops = flops(self.net, self.pruner.inputs)
|
|
|
logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
|
|
|
post_pruning_flops))
|