ema.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np


class ExponentialMovingAverage():
    """
    Exponential Moving Average.
    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
    """

    def __init__(self, model, decay, thres_steps=True):
        self._model = model
        self._decay = decay
        self._thres_steps = thres_steps
        self._shadow = {}
        self._backup = {}

    def register(self):
        """Take an initial snapshot of every trainable parameter."""
        self._update_step = 0
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                self._shadow[name] = param.numpy().copy()

    def update(self):
        """Fold the current parameters into the shadow (EMA) copies.

        When `thres_steps` is enabled, the effective decay is warmed up as
        min(decay, (1 + step) / (10 + step)), so early updates weight the
        latest parameters more heavily.
        """
        decay = min(self._decay, (1 + self._update_step) / (
            10 + self._update_step)) if self._thres_steps else self._decay
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                new_val = np.array(param.numpy().copy())
                old_val = np.array(self._shadow[name])
                new_average = decay * old_val + (1 - decay) * new_val
                self._shadow[name] = new_average
        self._update_step += 1
        return decay

    def apply(self):
        """Back up the current parameters and load the shadow (EMA) values into the model."""
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                self._backup[name] = np.array(param.numpy().copy())
                param.set_value(np.array(self._shadow[name]))

    def restore(self):
        """Restore the backed-up parameters, undoing `apply()`."""
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._backup
                param.set_value(self._backup[name])
        self._backup = {}
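

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It assumes a Paddle dygraph
# training loop; `model`, `optimizer`, `train_loader`, and `evaluate` are
# hypothetical placeholders for the caller's own objects.
#
#     ema = ExponentialMovingAverage(model, decay=0.999)
#     ema.register()                  # snapshot the initial trainable weights
#     for batch in train_loader:
#         loss = model(batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.clear_grad()
#         ema.update()                # fold the new weights into the shadow copy
#     ema.apply()                     # evaluate with the averaged weights
#     evaluate(model)
#     ema.restore()                   # switch back to the raw training weights
# ---------------------------------------------------------------------------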