optimizer.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.utils.logger import setup_logger

__all__ = ['LearningRate', 'OptimizerBuilder']

logger = setup_logger(__name__)


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay.

    Args:
        max_epochs (int): max epochs for the training process.
            If you combine cosine decay with warmup, it is recommended
            that max_iters be much larger than the warmup iterations.
        use_warmup (bool): whether to splice the warmup boundaries and
            values passed to __call__ into the schedule.
    """

    def __init__(self, max_epochs=1000, use_warmup=True):
        self.max_epochs = max_epochs
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        assert base_lr is not None, "either base LR or values should be provided"

        max_iters = self.max_epochs * int(step_per_epoch)

        if boundary is not None and value is not None and self.use_warmup:
            # extend the warmup schedule with one cosine-decayed value per iteration
            warmup_iters = len(boundary)
            for i in range(int(boundary[-1]), max_iters):
                boundary.append(i)
                decayed_lr = base_lr * 0.5 * (math.cos(
                    (i - warmup_iters) * math.pi /
                    (max_iters - warmup_iters)) + 1)
                value.append(decayed_lr)
            return optimizer.lr.PiecewiseDecay(boundary, value)

        return optimizer.lr.CosineAnnealingDecay(base_lr, T_max=max_iters)
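
# Illustrative note (not from the original file): with warmup enabled, the loop
# above appends one cosine-decayed value per iteration after the warmup phase,
# following lr_i = base_lr * 0.5 * (cos((i - warmup_iters) * pi /
# (max_iters - warmup_iters)) + 1), so the learning rate falls smoothly from
# roughly base_lr at the end of warmup to 0 at max_iters. Without warmup, the
# builtin paddle.optimizer.lr.CosineAnnealingDecay is returned instead.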


@serializable
class PiecewiseDecay(object):
    """
    Multi-step learning rate decay.

    Args:
        gamma (float | list): decay factor(s)
        milestones (list): epochs at which to decay the learning rate
        values (list): learning rate values set directly in the config;
            if given, gamma is ignored
        use_warmup (bool): whether to extend the warmup boundaries and
            values passed to __call__
    """

    def __init__(self,
                 gamma=[0.1, 0.01],
                 milestones=[8, 11],
                 values=None,
                 use_warmup=True):
        super(PiecewiseDecay, self).__init__()
        if type(gamma) is not list:
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None and self.use_warmup:
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])
        else:
            # do not use LinearWarmup
            boundary = [int(step_per_epoch) * i for i in self.milestones]
            value = [base_lr]  # lr in steps [0, boundary[0]) is base_lr

        # self.values is set directly in the config
        if self.values is not None:
            assert len(self.milestones) + 1 == len(self.values)
            return optimizer.lr.PiecewiseDecay(boundary, self.values)

        # otherwise the values are computed from base_lr and self.gamma
        value = value if value is not None else [base_lr]
        for i in self.gamma:
            value.append(base_lr * i)

        return optimizer.lr.PiecewiseDecay(boundary, value)
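
# Worked example (illustrative, not from the original file): with base_lr=0.01,
# gamma=[0.1, 0.01] and milestones=[8, 11], the returned schedule keeps the
# learning rate at 0.01 until epoch 8, drops it to 0.01 * 0.1 = 0.001 until
# epoch 11, and uses 0.01 * 0.01 = 0.0001 afterwards; warmup, if enabled,
# prepends its own boundaries and values.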


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        for i in range(self.steps + 1):
            if self.steps > 0:
                alpha = i / self.steps
                factor = self.start_factor * (1 - alpha) + alpha
                lr = base_lr * factor
                value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value
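
# Worked example (illustrative, not from the original file): with the defaults
# steps=500 and start_factor=1/3, the factor ramps linearly from 1/3 at i=0 to
# 1.0 at i=500; e.g. at i=250, factor = 1/3 * 0.5 + 0.5 = 2/3, so lr is
# 2/3 * base_lr. The returned (boundary, value) lists are then extended by the
# decay scheduler inside LearningRate.__call__.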


@serializable
class BurninWarmup(object):
    """
    Warm up learning rate in burnin mode

    Args:
        steps (int): warm up steps
    """

    def __init__(self, steps=1000):
        super(BurninWarmup, self).__init__()
        self.steps = steps

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        burnin = min(self.steps, step_per_epoch)
        for i in range(burnin + 1):
            factor = (i * 1.0 / burnin)**4
            lr = base_lr * factor
            value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = schedulers

    def __call__(self, step_per_epoch):
        assert len(self.schedulers) >= 1
        if not self.schedulers[0].use_warmup:
            return self.schedulers[0](base_lr=self.base_lr,
                                      step_per_epoch=step_per_epoch)

        # TODO: split warmup & decay
        # warmup
        boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
        # decay
        decay_lr = self.schedulers[0](self.base_lr, boundary, value,
                                      step_per_epoch)
        return decay_lr
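
# Usage sketch (hypothetical; in the real pipeline these objects are normally
# built from the ppdet config system rather than constructed by hand):
#
#   lr_builder = LearningRate(
#       base_lr=0.01,
#       schedulers=[PiecewiseDecay(gamma=[0.1, 0.01], milestones=[8, 11]),
#                   LinearWarmup(steps=500)])
#   lr_scheduler = lr_builder(step_per_epoch=500)  # a paddle LR scheduler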


@register
class OptimizerBuilder():
    """
    Build optimizer handles

    Args:
        regularizer (object): a `Regularizer` instance
        optimizer (object): an `Optimizer` instance
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, model=None):
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        else:
            grad_clip = None

        if self.regularizer and self.regularizer != 'None':
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        # the regularizer is passed as weight_decay for all optimizers
        # except AdamW
        if optim_type != 'AdamW':
            optim_args['weight_decay'] = regularization
        op = getattr(optimizer, optim_type)

        if 'without_weight_decay_params' in optim_args:
            # split parameters into a no-weight-decay group and the rest
            keys = optim_args['without_weight_decay_params']
            params = [{
                'params': [
                    p for n, p in model.named_parameters()
                    if any([k in n for k in keys])
                ],
                'weight_decay': 0.
            }, {
                'params': [
                    p for n, p in model.named_parameters()
                    if all([k not in n for k in keys])
                ]
            }]
            del optim_args['without_weight_decay_params']
        else:
            params = model.parameters()

        return op(learning_rate=learning_rate,
                  parameters=params,
                  grad_clip=grad_clip,
                  **optim_args)
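
# Usage sketch (hypothetical; assumes `model` is a paddle.nn.Layer and
# `lr_scheduler` was produced by LearningRate above):
#
#   opt_builder = OptimizerBuilder(
#       clip_grad_by_norm=35.,
#       regularizer={'type': 'L2', 'factor': 0.0001},
#       optimizer={'type': 'Momentum', 'momentum': 0.9})
#   opt = opt_builder(lr_scheduler, model)  # a paddle.optimizer.Momentum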


class ModelEMA(object):
    """
    Exponential Weighted Average (EMA) of model weights for deep neural
    networks.

    Args:
        model (nn.Layer): the detector model whose parameters are averaged.
        decay (float): the decay used for updating the EMA parameters.
            The EMA parameters are updated with the formula:
            `ema_param = decay * ema_param + (1 - decay) * cur_param`.
            Default is 0.9998.
        use_thres_step (bool): whether to ramp the decay up from a small
            value over the first steps instead of using the fixed value.
        cycle_epoch (int): the interval, in epochs, at which to reset
            ema_param and step. Default is -1, which means never reset.
            This adds a regularizing effect to EMA; it is set empirically
            and is useful when the total number of training epochs is large.
    """

    def __init__(self,
                 model,
                 decay=0.9998,
                 use_thres_step=False,
                 cycle_epoch=-1):
        self.step = 0
        self.epoch = 0
        self.decay = decay
        self.state_dict = dict()
        for k, v in model.state_dict().items():
            self.state_dict[k] = paddle.zeros_like(v)
        self.use_thres_step = use_thres_step
        self.cycle_epoch = cycle_epoch

    def reset(self):
        self.step = 0
        self.epoch = 0
        for k, v in self.state_dict.items():
            self.state_dict[k] = paddle.zeros_like(v)

    def update(self, model):
        if self.use_thres_step:
            # ramp the decay up towards self.decay during the early steps
            decay = min(self.decay, (1 + self.step) / (10 + self.step))
        else:
            decay = self.decay
        self._decay = decay
        model_dict = model.state_dict()
        for k, v in self.state_dict.items():
            v = decay * v + (1 - decay) * model_dict[k]
            v.stop_gradient = True
            self.state_dict[k] = v
        self.step += 1

    def apply(self):
        if self.step == 0:
            return self.state_dict
        state_dict = dict()
        for k, v in self.state_dict.items():
            # bias correction for the zero initialization of the EMA weights
            v = v / (1 - self._decay**self.step)
            v.stop_gradient = True
            state_dict[k] = v
        self.epoch += 1
        if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
            self.reset()
        return state_dict
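
# Usage sketch (hypothetical training loop; `model`, `loader` and the training
# step are placeholders, not part of this file):
#
#   ema = ModelEMA(model, decay=0.9998)
#   for epoch in range(num_epochs):
#       for data in loader:
#           ...  # forward / backward / optimizer step
#           ema.update(model)
#       ema_weights = ema.apply()  # bias-corrected averaged weights
#       # e.g. load them into a copy of the model (set_state_dict) for
#       # evaluation or saving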