Browse Source

Merge pull request #825 from will-jl944/develop_jf

Refine pruning
FlyingQianMM 4 years ago
parent
commit
2d155f6a80
2 changed files with 14 additions and 7 deletions
  1. 9 7
      dygraph/paddlex/cv/models/base.py
  2. 5 0
      dygraph/paddlex/cv/models/load_model.py

+ 9 - 7
dygraph/paddlex/cv/models/base.py

@@ -122,7 +122,14 @@ class BaseModel:
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
-        info['pruner_inputs'] = self.pruner.inputs
+        pruner_inputs = self.pruner.inputs
+        if self.model_type == 'detector':
+            pruner_inputs = {
+                k: v.tolist()
+                for k, v in pruner_inputs[0].items()
+            }
+        info['pruner_inputs'] = pruner_inputs
+
         return info
 
     def get_quant_info(self):
@@ -427,12 +434,7 @@ class BaseModel:
         pre_pruning_flops = flops(self.net, self.pruner.inputs)
         logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
             pre_pruning_flops))
-        skip_vars = []
-        for param in self.net.parameters():
-            if param.shape[0] <= 8:
-                skip_vars.append(param.name)
-        _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops,
-                                                 skip_vars)
+        _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
         post_pruning_flops = flops(self.net, self.pruner.inputs)
         logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
             post_pruning_flops))

+ 5 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -69,6 +69,11 @@ def load_model(model_dir):
             with open(osp.join(model_dir, "prune.yml")) as f:
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 inputs = pruning_info['pruner_inputs']
+                if model.model_type == 'detector':
+                    inputs = [{
+                        k: paddle.to_tensor(v)
+                        for k, v in inputs.items()
+                    }]
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                     model.net, inputs=inputs)
                 model.pruning_ratios = pruning_info['pruning_ratios']