|
|
@@ -69,6 +69,11 @@ def load_model(model_dir):
|
|
|
with open(osp.join(model_dir, "prune.yml")) as f:
|
|
|
pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
inputs = pruning_info['pruner_inputs']
|
|
|
+ if model.model_type == 'detector':
|
|
|
+ inputs = [{
|
|
|
+ k: paddle.to_tensor(v)
|
|
|
+ for k, v in inputs.items()
|
|
|
+ }]
|
|
|
model.pruner = getattr(paddleslim, pruning_info['pruner'])(
|
|
|
model.net, inputs=inputs)
|
|
|
model.pruning_ratios = pruning_info['pruning_ratios']
|