optimizer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import paddle
import paddle.regularizer as regularizer

__all__ = ['OptimizerBuilder']


class L1Decay(object):
    """
    L1 Weight Decay Regularization, which encourages the weights to be sparse.

    Args:
        factor (float): regularization coefficient. Default: 0.0.
    """

    def __init__(self, factor=0.0):
        super(L1Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L1Decay(self.factor)
        return reg


class L2Decay(object):
    """
    L2 Weight Decay Regularization, which encourages the weights to stay small.

    Args:
        factor (float): regularization coefficient. Default: 0.0.
    """

    def __init__(self, factor=0.0):
        super(L2Decay, self).__init__()
        self.factor = factor

    def __call__(self):
        reg = regularizer.L2Decay(self.factor)
        return reg
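

# Note (sketch): the two wrapper classes above only defer construction of the
# paddle regularizer objects. Calling, e.g., L2Decay(factor=1e-4)() yields the
# same kind of object as paddle.regularizer.L2Decay(1e-4), which the optimizer
# wrappers below forward to paddle.optimizer.* as ``weight_decay``. The 1e-4
# factor here is an arbitrary illustrative value.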


class Momentum(object):
    """
    Simple Momentum optimizer with velocity state.

    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        parameter_list (list, optional) - Parameters to be optimized.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
        multi_precision (bool, optional) - Whether to keep FP32 master weights for
            mixed-precision training. Default: False.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 parameter_list=None,
                 regularization=None,
                 multi_precision=False,
                 **args):
        super(Momentum, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.parameter_list = parameter_list
        self.regularization = regularization
        self.multi_precision = multi_precision

    def __call__(self):
        opt = paddle.optimizer.Momentum(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            parameters=self.parameter_list,
            weight_decay=self.regularization,
            multi_precision=self.multi_precision)
        return opt


class RMSProp(object):
    """
    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.

    Args:
        learning_rate (float|Variable) - The learning rate used to update parameters.
            Can be a float value or a Variable with one float value as data element.
        momentum (float) - Momentum factor.
        rho (float) - Decay rate of the moving average of squared gradients. Default: 0.95.
        epsilon (float) - Small value added to avoid division by zero. Default: 1e-6.
        parameter_list (list, optional) - Parameters to be optimized.
        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 rho=0.95,
                 epsilon=1e-6,
                 parameter_list=None,
                 regularization=None,
                 **args):
        super(RMSProp, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.rho = rho
        self.epsilon = epsilon
        self.parameter_list = parameter_list
        self.regularization = regularization

    def __call__(self):
        opt = paddle.optimizer.RMSProp(
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            rho=self.rho,
            epsilon=self.epsilon,
            parameters=self.parameter_list,
            weight_decay=self.regularization)
        return opt


class OptimizerBuilder(object):
    """
    Build an optimizer from configuration.

    Args:
        function (str): name of an optimizer class defined in this module, e.g. 'Momentum'.
        params (dict): keyword arguments used to initialize that class.
        regularizer (dict): configuration used to create the weight-decay regularization;
            the 'function' key selects the L1Decay or L2Decay wrapper ('L1'/'L2'), and the
            remaining keys are passed to its constructor.
    """

    def __init__(self,
                 function='Momentum',
                 params={'momentum': 0.9},
                 regularizer=None):
        self.function = function
        # copy to avoid mutating the shared default argument when the
        # regularization entry is added below
        self.params = dict(params)
        # create regularizer
        if regularizer is not None:
            mod = sys.modules[__name__]
            reg_func = regularizer['function'] + 'Decay'
            del regularizer['function']
            reg = getattr(mod, reg_func)(**regularizer)()
            self.params['regularization'] = reg

    def __call__(self, learning_rate, parameter_list=None):
        # look up the optimizer wrapper by name and instantiate it
        mod = sys.modules[__name__]
        opt = getattr(mod, self.function)
        return opt(learning_rate=learning_rate,
                   parameter_list=parameter_list,
                   **self.params)()
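

if __name__ == '__main__':
    # Minimal usage sketch: build a Momentum optimizer with L2 weight decay for
    # a toy model. The values (momentum=0.9, factor=1e-4, learning_rate=0.1) and
    # the Linear layer are illustrative only.
    model = paddle.nn.Linear(10, 2)
    builder = OptimizerBuilder(
        function='Momentum',
        params={'momentum': 0.9},
        regularizer={'function': 'L2', 'factor': 1e-4})
    optimizer = builder(learning_rate=0.1,
                        parameter_list=model.parameters())
    print(type(optimizer))  # an instance of paddle.optimizer.Momentum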