ema.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. class ExponentialMovingAverage(object):
  16. def __init__(self, decay, model, use_thres_step=False):
  17. self.step = 0
  18. self.decay = decay
  19. self.shadow = dict()
  20. for k, v in model.state_dict().items():
  21. self.shadow[k] = paddle.zeros_like(v)
  22. self.use_thres_step = use_thres_step
  23. def update(self, model):
  24. if self.use_thres_step:
  25. decay = min(self.decay, (1 + self.step) / (10 + self.step))
  26. else:
  27. decay = self.decay
  28. self._decay = decay
  29. model_dict = model.state_dict()
  30. for k, v in self.shadow.items():
  31. v = decay * v + (1 - decay) * model_dict[k]
  32. v.stop_gradient = True
  33. self.shadow[k] = v
  34. self.step += 1
  35. def apply(self):
  36. if self.step == 0:
  37. return self.shadow
  38. state_dict = dict()
  39. for k, v in self.shadow.items():
  40. v = v / (1 - self._decay**self.step)
  41. v.stop_gradient = True
  42. state_dict[k] = v
  43. return state_dict