Переглянути джерело

convert tensor to list while saving pruner_inputs

will-jl944 4 роки тому
батько
коміт
ed65b01157

+ 8 - 1
dygraph/paddlex/cv/models/base.py

@@ -122,7 +122,14 @@ class BaseModel:
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
-        info['pruner_inputs'] = self.pruner.inputs
+        pruner_inputs = self.pruner.inputs
+        if self.model_type == 'detector':
+            pruner_inputs = {
+                k: v.tolist()
+                for k, v in pruner_inputs[0].items()
+            }
+        info['pruner_inputs'] = pruner_inputs
+
         return info
 
     def get_quant_info(self):

+ 5 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -69,6 +69,11 @@ def load_model(model_dir):
             with open(osp.join(model_dir, "prune.yml")) as f:
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 inputs = pruning_info['pruner_inputs']
+                if model.model_type == 'detector':
+                    inputs = [{
+                        k: paddle.to_tensor(v)
+                        for k, v in inputs.items()
+                    }]
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                     model.net, inputs=inputs)
                 model.pruning_ratios = pruning_info['pruning_ratios']