# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from paddlex.ppdet.core.workspace import register, serializable

__all__ = ['LearningRate', 'OptimizerBuilder']

from paddlex.ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay.

    Args:
        max_epochs (int): maximum number of epochs for the training process.
            If cosine decay is combined with warmup, it is recommended that
            the total number of iterations be much larger than the number of
            warmup iterations.
        use_warmup (bool): whether the schedule is combined with a warmup
            scheduler.
    """

    def __init__(self, max_epochs=1000, use_warmup=True):
        self.max_epochs = max_epochs
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        assert base_lr is not None, "either base LR or values should be provided"
        max_iters = self.max_epochs * int(step_per_epoch)
        if boundary is not None and value is not None and self.use_warmup:
            # Extend the warmup boundary/value lists with one cosine-decayed
            # learning rate per remaining iteration.
            for i in range(int(boundary[-1]), max_iters):
                boundary.append(i)
                decayed_lr = base_lr * 0.5 * (
                    math.cos(i * math.pi / max_iters) + 1)
                value.append(decayed_lr)
            return optimizer.lr.PiecewiseDecay(boundary, value)
        return optimizer.lr.CosineAnnealingDecay(base_lr, T_max=max_iters)
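
# Illustrative sketch (my addition, not part of the original module): how
# CosineDecay behaves with and without warmup. With use_warmup=True and a
# warmup boundary/value pair it returns a PiecewiseDecay that appends one
# cosine-decayed value per remaining iteration; otherwise it returns a plain
# CosineAnnealingDecay over max_epochs * step_per_epoch iterations.
#
#   cosine = CosineDecay(max_epochs=12, use_warmup=False)
#   scheduler = cosine(base_lr=0.01, step_per_epoch=500)
#   # -> paddle.optimizer.lr.CosineAnnealingDecay(0.01, T_max=6000)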


@serializable
class PiecewiseDecay(object):
    """
    Multi-step learning rate decay.

    Args:
        gamma (float | list): decay factor(s). A scalar is expanded to
            ``gamma / 10**i`` for each milestone.
        milestones (list): epochs at which to decay the learning rate.
        values (list): learning rate values to use directly instead of
            computing them from ``gamma``.
        use_warmup (bool): whether the schedule is combined with a warmup
            scheduler.
    """

    def __init__(self,
                 gamma=[0.1, 0.01],
                 milestones=[8, 11],
                 values=None,
                 use_warmup=True):
        super(PiecewiseDecay, self).__init__()
        if type(gamma) is not list:
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None and self.use_warmup:
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])
        else:
            # do not use LinearWarmup
            boundary = [int(step_per_epoch) * i for i in self.milestones]
            value = [base_lr]  # base_lr is used for steps in [0, boundary[0])

        # self.values is set directly in the config
        if self.values is not None:
            assert len(self.milestones) + 1 == len(self.values)
            return optimizer.lr.PiecewiseDecay(boundary, self.values)

        # otherwise the values are computed from self.gamma
        value = value if value is not None else [base_lr]
        for i in self.gamma:
            value.append(base_lr * i)
        return optimizer.lr.PiecewiseDecay(boundary, value)
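
# Illustrative sketch (my addition, not in the original file): how a scalar
# gamma expands. With gamma=0.1 and milestones=[8, 11], __init__ stores
# self.gamma == [0.1, 0.01], so a base_lr of 0.01 yields piecewise values
# [0.01, 0.001, 0.0001] at boundaries 8 and 11 epochs (times step_per_epoch).
#
#   pd = PiecewiseDecay(gamma=0.1, milestones=[8, 11], use_warmup=False)
#   scheduler = pd(base_lr=0.01, step_per_epoch=500)
#   # -> paddle.optimizer.lr.PiecewiseDecay([4000, 5500], [0.01, 0.001, 0.0001])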


@serializable
class LinearWarmup(object):
    """
    Warm up the learning rate linearly.

    Args:
        steps (int): warm up steps.
        start_factor (float): initial learning rate factor.
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        for i in range(self.steps + 1):
            if self.steps > 0:
                alpha = i / self.steps
                factor = self.start_factor * (1 - alpha) + alpha
                lr = base_lr * factor
                value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value
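
# Illustrative sketch (my addition): the boundary/value lists produced by a
# short warmup. With steps=4, start_factor=1/3 and base_lr=0.01 the factors
# are 1/3, 1/2, 2/3, 5/6, 1, so:
#
#   warmup = LinearWarmup(steps=4, start_factor=1. / 3)
#   boundary, value = warmup(base_lr=0.01, step_per_epoch=500)
#   # boundary == [1, 2, 3, 4]
#   # value    == [0.01 / 3, 0.005, 0.02 / 3, 0.01 * 5 / 6, 0.01]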


@serializable
class BurninWarmup(object):
    """
    Warm up the learning rate in burn-in mode.

    Args:
        steps (int): warm up steps.
    """

    def __init__(self, steps=1000):
        super(BurninWarmup, self).__init__()
        self.steps = steps

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        burnin = min(self.steps, step_per_epoch)
        for i in range(burnin + 1):
            factor = (i * 1.0 / burnin)**4
            lr = base_lr * factor
            value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value
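
# Illustrative sketch (my addition): burn-in warmup ramps the learning rate
# along a quartic curve, factor = (i / burnin) ** 4, where burnin is the
# smaller of `steps` and one epoch worth of iterations.
#
#   burnin = BurninWarmup(steps=1000)
#   boundary, value = burnin(base_lr=0.01, step_per_epoch=500)
#   # burn-in length is min(1000, 500) == 500; value ramps from 0 to 0.01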


@register
class LearningRate(object):
    """
    Learning rate configuration.

    Args:
        base_lr (float): base learning rate.
        schedulers (list): learning rate schedulers.
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = schedulers

    def __call__(self, step_per_epoch):
        assert len(self.schedulers) >= 1
        if not self.schedulers[0].use_warmup:
            return self.schedulers[0](base_lr=self.base_lr,
                                      step_per_epoch=step_per_epoch)

        # TODO: split warmup & decay
        # warmup
        boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
        # decay
        decay_lr = self.schedulers[0](self.base_lr, boundary, value,
                                      step_per_epoch)
        return decay_lr
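
# Illustrative sketch (my addition): composing the default schedulers. The
# warmup scheduler (index 1) builds the initial boundary/value lists and the
# decay scheduler (index 0) extends them into the final Paddle LR scheduler.
#
#   lr_cfg = LearningRate(base_lr=0.01,
#                         schedulers=[PiecewiseDecay(), LinearWarmup()])
#   scheduler = lr_cfg(step_per_epoch=500)
#   # -> paddle.optimizer.lr.PiecewiseDecay combining warmup and step decay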


@register
class OptimizerBuilder():
    """
    Build optimizer handles.

    Args:
        regularizer (object): a `Regularizer` instance.
        optimizer (object): an `Optimizer` instance.
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, params=None):
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        else:
            grad_clip = None

        if self.regularizer and self.regularizer != 'None':
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        if optim_type != 'AdamW':
            # AdamW applies its own decoupled weight decay; other optimizers
            # receive the regularizer as their weight_decay argument.
            optim_args['weight_decay'] = regularization
        op = getattr(optimizer, optim_type)
        return op(learning_rate=learning_rate,
                  parameters=params,
                  grad_clip=grad_clip,
                  **optim_args)
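
# Illustrative sketch (my addition): building the default SGD-with-momentum
# optimizer. The regularizer dict resolves to paddle.regularizer.L2Decay and
# is passed as weight_decay; grad_clip stays None unless clip_grad_by_norm is
# set. `model` and `scheduler` are assumed to be defined elsewhere.
#
#   builder = OptimizerBuilder()            # Momentum + L2(0.0001) defaults
#   optim = builder(scheduler, params=model.parameters())
#   # -> paddle.optimizer.Momentum(learning_rate=scheduler, momentum=0.9,
#   #                              weight_decay=L2Decay(0.0001), ...)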


class ModelEMA(object):
    """
    Exponential moving average (EMA) of model weights.

    Args:
        decay (float): EMA decay rate.
        model (nn.Layer): model whose parameters are tracked.
        use_thres_step (bool): if True, ramp the effective decay up with the
            number of update steps.
    """

    def __init__(self, decay, model, use_thres_step=False):
        self.step = 0
        self.decay = decay
        # Shadow weights start at zero; apply() corrects for this bias.
        self.state_dict = dict()
        for k, v in model.state_dict().items():
            self.state_dict[k] = paddle.zeros_like(v)
        self.use_thres_step = use_thres_step

    def update(self, model):
        if self.use_thres_step:
            decay = min(self.decay, (1 + self.step) / (10 + self.step))
        else:
            decay = self.decay
        self._decay = decay
        model_dict = model.state_dict()
        for k, v in self.state_dict.items():
            v = decay * v + (1 - decay) * model_dict[k]
            v.stop_gradient = True
            self.state_dict[k] = v
        self.step += 1

    def apply(self):
        if self.step == 0:
            return self.state_dict
        state_dict = dict()
        for k, v in self.state_dict.items():
            # Bias correction for the zero-initialized shadow weights.
            v = v / (1 - self._decay**self.step)
            v.stop_gradient = True
            state_dict[k] = v
        return state_dict
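
# Illustrative sketch (my addition): a typical EMA loop. The shadow weights
# are updated after every optimizer step, and apply() returns a bias-corrected
# state dict that can be loaded for evaluation. `model`, `loader` and `optim`
# are assumed to be defined elsewhere.
#
#   ema = ModelEMA(decay=0.9998, model=model, use_thres_step=True)
#   for data in loader:
#       loss = model(data)['loss']
#       loss.backward()
#       optim.step()
#       optim.clear_grad()
#       ema.update(model)
#   model.set_dict(ema.apply())   # evaluate with the EMA weights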