loss.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn.functional as F
  16. __all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss']
  17. class Loss(object):
  18. """
  19. Loss
  20. """
  21. def __init__(self, class_dim=1000, epsilon=None):
  22. assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
  23. self._class_dim = class_dim
  24. if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
  25. self._epsilon = epsilon
  26. self._label_smoothing = True
  27. else:
  28. self._epsilon = None
  29. self._label_smoothing = False
  30. def _labelsmoothing(self, target):
  31. if target.shape[-1] != self._class_dim:
  32. one_hot_target = F.one_hot(target, self._class_dim)
  33. else:
  34. one_hot_target = target
  35. soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
  36. soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
  37. return soft_target
  38. def _binary_crossentropy(self, input, target):
  39. if self._label_smoothing:
  40. target = self._labelsmoothing(target)
  41. cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
  42. else:
  43. cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
  44. avg_cost = paddle.mean(cost)
  45. return avg_cost
  46. def _crossentropy(self, input, target):
  47. if self._label_smoothing:
  48. target = self._labelsmoothing(target)
  49. input = -F.log_softmax(input, axis=-1)
  50. cost = paddle.sum(target * input, axis=-1)
  51. else:
  52. cost = F.cross_entropy(input=input, label=target)
  53. avg_cost = paddle.mean(cost)
  54. return avg_cost
  55. def _kldiv(self, input, target, name=None):
  56. eps = 1.0e-10
  57. cost = target * paddle.log(
  58. (target + eps) / (input + eps)) * self._class_dim
  59. return cost
  60. def _jsdiv(self, input, target):
  61. input = F.softmax(input)
  62. target = F.softmax(target)
  63. cost = self._kldiv(input, target) + self._kldiv(target, input)
  64. cost = cost / 2
  65. avg_cost = paddle.mean(cost)
  66. return avg_cost
  67. def __call__(self, input, target):
  68. pass
  69. class MultiLabelLoss(Loss):
  70. """
  71. Multilabel loss based binary cross entropy
  72. """
  73. def __init__(self, class_dim=1000, epsilon=None):
  74. super(MultiLabelLoss, self).__init__(class_dim, epsilon)
  75. def __call__(self, input, target):
  76. cost = self._binary_crossentropy(input, target)
  77. return cost
  78. class CELoss(Loss):
  79. """
  80. Cross entropy loss
  81. """
  82. def __init__(self, class_dim=1000, epsilon=None):
  83. super(CELoss, self).__init__(class_dim, epsilon)
  84. def __call__(self, input, target):
  85. cost = self._crossentropy(input, target)
  86. return cost
  87. class MixCELoss(Loss):
  88. """
  89. Cross entropy loss with mix(mixup, cutmix, fixmix)
  90. """
  91. def __init__(self, class_dim=1000, epsilon=None):
  92. super(MixCELoss, self).__init__(class_dim, epsilon)
  93. def __call__(self, input, target0, target1, lam):
  94. cost0 = self._crossentropy(input, target0)
  95. cost1 = self._crossentropy(input, target1)
  96. cost = lam * cost0 + (1.0 - lam) * cost1
  97. avg_cost = paddle.mean(cost)
  98. return avg_cost
  99. class GoogLeNetLoss(Loss):
  100. """
  101. Cross entropy loss used after googlenet
  102. """
  103. def __init__(self, class_dim=1000, epsilon=None):
  104. super(GoogLeNetLoss, self).__init__(class_dim, epsilon)
  105. def __call__(self, input0, input1, input2, target):
  106. cost0 = self._crossentropy(input0, target)
  107. cost1 = self._crossentropy(input1, target)
  108. cost2 = self._crossentropy(input2, target)
  109. cost = cost0 + 0.3 * cost1 + 0.3 * cost2
  110. avg_cost = paddle.mean(cost)
  111. return avg_cost
  112. class JSDivLoss(Loss):
  113. """
  114. JSDiv loss
  115. """
  116. def __init__(self, class_dim=1000, epsilon=None):
  117. super(JSDivLoss, self).__init__(class_dim, epsilon)
  118. def __call__(self, input, target):
  119. cost = self._jsdiv(input, target)
  120. return cost