prune.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddleslim
  16. FILTER_DIM = paddleslim.dygraph.prune.filter_pruner.FILTER_DIM
  17. def _pruner_eval_fn(model, eval_dataset, batch_size):
  18. metric = model.evaluate(eval_dataset, batch_size=batch_size)
  19. return metric[list(metric.keys())[0]]
  20. def _pruner_template_input(sample, model_type):
  21. if model_type == 'detector':
  22. template_input = [{
  23. "image": paddle.ones(
  24. shape=[1, 3] + list(sample["image"].shape[:2]),
  25. dtype='float32'),
  26. "im_shape": paddle.full(
  27. [1, 2], 640, dtype='float32'),
  28. "scale_factor": paddle.ones(
  29. shape=[1, 2], dtype='float32')
  30. }]
  31. else:
  32. template_input = [1] + list(sample[0].shape)
  33. return template_input
  34. def sensitive_prune(pruner, pruned_flops, skip_vars=[], align=None):
  35. # skip depthwise convolutions
  36. for layer in pruner.model.sublayers():
  37. if isinstance(layer,
  38. paddle.nn.layer.conv.Conv2D) and layer._groups > 1:
  39. for param in layer.parameters(include_sublayers=False):
  40. skip_vars.append(param.name)
  41. pruner.restore()
  42. ratios, pruned_flops = pruner.get_ratios_by_sensitivity(
  43. pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars)
  44. pruner.plan = pruner.prune_vars(ratios, FILTER_DIM)
  45. pruner.plan._pruned_flops = pruned_flops
  46. return pruner.plan, ratios