| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import paddle
- import numpy as np
class ExponentialMovingAverage():
    """
    Exponential Moving Average of a model's trainable parameters.

    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py

    Keeps a NumPy "shadow" copy of every parameter whose ``stop_gradient``
    is False and blends the live values into it on each ``update()``.
    ``apply()`` swaps the shadow values into the model (backing up the live
    ones) and ``restore()`` writes the backed-up values back.

    Typical call order: ``register()`` once, ``update()`` after each
    optimizer step, then ``apply()`` / ``restore()`` around evaluation.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs; each param must provide
            ``stop_gradient``, ``numpy()`` and ``set_value()``.
        decay (float): upper bound on the decay factor.
        thres_steps (bool): if True, the decay used at step ``t`` is
            ``min(decay, (1 + t) / (10 + t))``, so early updates weight the
            live values more heavily; if False, ``decay`` is used as-is.
    """

    def __init__(self, model, decay, thres_steps=True):
        self._model = model
        self._decay = decay
        self._thres_steps = thres_steps
        self._shadow = {}
        self._backup = {}
        # Also initialised in register(); setting it here means update()
        # called before register() does not fail with AttributeError.
        self._update_step = 0

    def _trainable_parameters(self):
        """Yield ``(name, param)`` for parameters with stop_gradient unset."""
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                yield name, param

    def register(self):
        """Snapshot the current trainable weights as the initial shadow
        values and reset the update-step counter."""
        self._update_step = 0
        for name, param in self._trainable_parameters():
            self._shadow[name] = param.numpy().copy()

    def update(self):
        """Blend the live trainable weights into their shadow copies.

        shadow = decay * shadow + (1 - decay) * live

        Returns:
            float: the decay factor used for this step.
        """
        if self._thres_steps:
            step = self._update_step
            decay = min(self._decay, (1 + step) / (10 + step))
        else:
            decay = self._decay
        for name, param in self._trainable_parameters():
            assert name in self._shadow, (
                "parameter {!r} has no shadow value; "
                "call register() first".format(name))
            new_val = np.array(param.numpy())
            self._shadow[name] = (
                decay * self._shadow[name] + (1 - decay) * new_val)
        self._update_step += 1
        return decay

    def apply(self):
        """Load the shadow (averaged) weights into the model, backing up the
        live weights so restore() can put them back."""
        for name, param in self._trainable_parameters():
            assert name in self._shadow, (
                "parameter {!r} has no shadow value; "
                "call register() first".format(name))
            # Only snapshot on the first apply(): a repeated apply() without
            # restore() would otherwise back up the EMA weights and lose the
            # original ones.
            if name not in self._backup:
                self._backup[name] = np.array(param.numpy())
            param.set_value(np.array(self._shadow[name]))

    def restore(self):
        """Write the weights saved by apply() back into the model and clear
        the backup."""
        for name, param in self._trainable_parameters():
            assert name in self._backup, (
                "parameter {!r} has no backup; "
                "call apply() before restore()".format(name))
            param.set_value(self._backup[name])
        self._backup = {}
|