Explorar o código

convert tensor to list while saving pruner_inputs

will-jl944 %!s(int64=4) %!d(string=hai) anos
pai
achega
ed65b01157
Modificáronse 2 ficheiros con 13 adicións e 1 borrados
  1. 8 1
      dygraph/paddlex/cv/models/base.py
  2. 5 0
      dygraph/paddlex/cv/models/load_model.py

+ 8 - 1
dygraph/paddlex/cv/models/base.py

@@ -122,7 +122,14 @@ class BaseModel:
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
-        info['pruner_inputs'] = self.pruner.inputs
+        pruner_inputs = self.pruner.inputs
+        if self.model_type == 'detector':
+            pruner_inputs = {
+                k: v.tolist()
+                for k, v in pruner_inputs[0].items()
+            }
+        info['pruner_inputs'] = pruner_inputs
+
         return info
 
     def get_quant_info(self):

+ 5 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -69,6 +69,11 @@ def load_model(model_dir):
             with open(osp.join(model_dir, "prune.yml")) as f:
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 inputs = pruning_info['pruner_inputs']
+                if model.model_type == 'detector':
+                    inputs = [{
+                        k: paddle.to_tensor(v)
+                        for k, v in inputs.items()
+                    }]
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                     model.net, inputs=inputs)
                 model.pruning_ratios = pruning_info['pruning_ratios']