Browse Source

convert tensor to list while saving pruner_inputs

will-jl944 4 years ago
parent
commit
ed65b01157
2 changed files with 13 additions and 1 deletions
  1. 8 1
      dygraph/paddlex/cv/models/base.py
  2. 5 0
      dygraph/paddlex/cv/models/load_model.py

+ 8 - 1
dygraph/paddlex/cv/models/base.py

@@ -122,7 +122,14 @@ class BaseModel:
         info = dict()
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
         info['pruning_ratios'] = self.pruning_ratios
-        info['pruner_inputs'] = self.pruner.inputs
+        pruner_inputs = self.pruner.inputs
+        if self.model_type == 'detector':
+            pruner_inputs = {
+                k: v.tolist()
+                for k, v in pruner_inputs[0].items()
+            }
+        info['pruner_inputs'] = pruner_inputs
+
         return info
         return info
 
 
     def get_quant_info(self):
     def get_quant_info(self):

+ 5 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -69,6 +69,11 @@ def load_model(model_dir):
             with open(osp.join(model_dir, "prune.yml")) as f:
             with open(osp.join(model_dir, "prune.yml")) as f:
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 inputs = pruning_info['pruner_inputs']
                 inputs = pruning_info['pruner_inputs']
+                if model.model_type == 'detector':
+                    inputs = [{
+                        k: paddle.to_tensor(v)
+                        for k, v in inputs.items()
+                    }]
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                     model.net, inputs=inputs)
                     model.net, inputs=inputs)
                 model.pruning_ratios = pruning_info['pruning_ratios']
                 model.pruning_ratios = pruning_info['pruning_ratios']