keypoint_loss.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from itertools import cycle, islice
  18. from collections import abc
  19. import paddle
  20. import paddle.nn as nn
  21. from paddlex.ppdet.core.workspace import register, serializable
  22. __all__ = ['HrHRNetLoss', 'KeyPointMSELoss']
  23. @register
  24. @serializable
  25. class KeyPointMSELoss(nn.Layer):
  26. def __init__(self, use_target_weight=True, loss_scale=0.5):
  27. """
  28. KeyPointMSELoss layer
  29. Args:
  30. use_target_weight (bool): whether to use target weight
  31. """
  32. super(KeyPointMSELoss, self).__init__()
  33. self.criterion = nn.MSELoss(reduction='mean')
  34. self.use_target_weight = use_target_weight
  35. self.loss_scale = loss_scale
  36. def forward(self, output, records):
  37. target = records['target']
  38. target_weight = records['target_weight']
  39. batch_size = output.shape[0]
  40. num_joints = output.shape[1]
  41. heatmaps_pred = output.reshape(
  42. (batch_size, num_joints, -1)).split(num_joints, 1)
  43. heatmaps_gt = target.reshape(
  44. (batch_size, num_joints, -1)).split(num_joints, 1)
  45. loss = 0
  46. for idx in range(num_joints):
  47. heatmap_pred = heatmaps_pred[idx].squeeze()
  48. heatmap_gt = heatmaps_gt[idx].squeeze()
  49. if self.use_target_weight:
  50. loss += self.loss_scale * self.criterion(
  51. heatmap_pred.multiply(target_weight[:, idx]),
  52. heatmap_gt.multiply(target_weight[:, idx]))
  53. else:
  54. loss += self.loss_scale * self.criterion(heatmap_pred,
  55. heatmap_gt)
  56. keypoint_losses = dict()
  57. keypoint_losses['loss'] = loss / num_joints
  58. return keypoint_losses
  59. @register
  60. @serializable
  61. class HrHRNetLoss(nn.Layer):
  62. def __init__(self, num_joints, swahr):
  63. """
  64. HrHRNetLoss layer
  65. Args:
  66. num_joints (int): number of keypoints
  67. """
  68. super(HrHRNetLoss, self).__init__()
  69. if swahr:
  70. self.heatmaploss = HeatMapSWAHRLoss(num_joints)
  71. else:
  72. self.heatmaploss = HeatMapLoss()
  73. self.aeloss = AELoss()
  74. self.ziploss = ZipLoss(
  75. [self.heatmaploss, self.heatmaploss, self.aeloss])
  76. def forward(self, inputs, records):
  77. targets = []
  78. targets.append([records['heatmap_gt1x'], records['mask_1x']])
  79. targets.append([records['heatmap_gt2x'], records['mask_2x']])
  80. targets.append(records['tagmap'])
  81. keypoint_losses = dict()
  82. loss = self.ziploss(inputs, targets)
  83. keypoint_losses['heatmap_loss'] = loss[0] + loss[1]
  84. keypoint_losses['pull_loss'] = loss[2][0]
  85. keypoint_losses['push_loss'] = loss[2][1]
  86. keypoint_losses['loss'] = recursive_sum(loss)
  87. return keypoint_losses
  88. class HeatMapLoss(object):
  89. def __init__(self, loss_factor=1.0):
  90. super(HeatMapLoss, self).__init__()
  91. self.loss_factor = loss_factor
  92. def __call__(self, preds, targets):
  93. heatmap, mask = targets
  94. loss = ((preds - heatmap)**2 * mask.cast('float').unsqueeze(1))
  95. loss = paddle.clip(loss, min=0, max=2).mean()
  96. loss *= self.loss_factor
  97. return loss
  98. class HeatMapSWAHRLoss(object):
  99. def __init__(self, num_joints, loss_factor=1.0):
  100. super(HeatMapSWAHRLoss, self).__init__()
  101. self.loss_factor = loss_factor
  102. self.num_joints = num_joints
  103. def __call__(self, preds, targets):
  104. heatmaps_gt, mask = targets
  105. heatmaps_pred = preds[0]
  106. scalemaps_pred = preds[1]
  107. heatmaps_scaled_gt = paddle.where(
  108. heatmaps_gt > 0, 0.5 * heatmaps_gt *
  109. (1 + (1 +
  110. (scalemaps_pred - 1.) * paddle.log(heatmaps_gt + 1e-10))**2),
  111. heatmaps_gt)
  112. regularizer_loss = paddle.mean(
  113. paddle.pow((scalemaps_pred - 1.) * (heatmaps_gt > 0).astype(float),
  114. 2))
  115. omiga = 0.01
  116. # thres = 2**(-1/omiga), threshold for positive weight
  117. hm_weight = heatmaps_scaled_gt**(
  118. omiga
  119. ) * paddle.abs(1 - heatmaps_pred) + paddle.abs(heatmaps_pred) * (
  120. 1 - heatmaps_scaled_gt**(omiga))
  121. loss = (((heatmaps_pred - heatmaps_scaled_gt)**2) *
  122. mask.cast('float').unsqueeze(1)) * hm_weight
  123. loss = loss.mean()
  124. loss = self.loss_factor * (loss + 1.0 * regularizer_loss)
  125. return loss
  126. class AELoss(object):
  127. def __init__(self, pull_factor=0.001, push_factor=0.001):
  128. super(AELoss, self).__init__()
  129. self.pull_factor = pull_factor
  130. self.push_factor = push_factor
  131. def apply_single(self, pred, tagmap):
  132. if tagmap.numpy()[:, :, 3].sum() == 0:
  133. return (paddle.zeros([1]), paddle.zeros([1]))
  134. nonzero = paddle.nonzero(tagmap[:, :, 3] > 0)
  135. if nonzero.shape[0] == 0:
  136. return (paddle.zeros([1]), paddle.zeros([1]))
  137. p_inds = paddle.unique(nonzero[:, 0])
  138. num_person = p_inds.shape[0]
  139. if num_person == 0:
  140. return (paddle.zeros([1]), paddle.zeros([1]))
  141. pull = 0
  142. tagpull_num = 0
  143. embs_all = []
  144. person_unvalid = 0
  145. for person_idx in p_inds.numpy():
  146. valid_single = tagmap[person_idx.item()]
  147. validkpts = paddle.nonzero(valid_single[:, 3] > 0)
  148. valid_single = paddle.index_select(valid_single, validkpts)
  149. emb = paddle.gather_nd(pred, valid_single[:, :3])
  150. if emb.shape[0] == 1:
  151. person_unvalid += 1
  152. mean = paddle.mean(emb, axis=0)
  153. embs_all.append(mean)
  154. pull += paddle.mean(paddle.pow(emb - mean, 2), axis=0)
  155. tagpull_num += emb.shape[0]
  156. pull /= max(num_person - person_unvalid, 1)
  157. if num_person < 2:
  158. return pull, paddle.zeros([1])
  159. embs_all = paddle.stack(embs_all)
  160. A = embs_all.expand([num_person, num_person])
  161. B = A.transpose([1, 0])
  162. diff = A - B
  163. diff = paddle.pow(diff, 2)
  164. push = paddle.exp(-diff)
  165. push = paddle.sum(push) - num_person
  166. push /= 2 * num_person * (num_person - 1)
  167. return pull, push
  168. def __call__(self, preds, tagmaps):
  169. bs = preds.shape[0]
  170. losses = [
  171. self.apply_single(preds[i:i + 1].squeeze(),
  172. tagmaps[i:i + 1].squeeze()) for i in range(bs)
  173. ]
  174. pull = self.pull_factor * sum(loss[0] for loss in losses) / len(losses)
  175. push = self.push_factor * sum(loss[1] for loss in losses) / len(losses)
  176. return pull, push
  177. class ZipLoss(object):
  178. def __init__(self, loss_funcs):
  179. super(ZipLoss, self).__init__()
  180. self.loss_funcs = loss_funcs
  181. def __call__(self, inputs, targets):
  182. assert len(self.loss_funcs) == len(targets) >= len(inputs)
  183. def zip_repeat(*args):
  184. longest = max(map(len, args))
  185. filled = [islice(cycle(x), longest) for x in args]
  186. return zip(*filled)
  187. return tuple(
  188. fn(x, y)
  189. for x, y, fn in zip_repeat(inputs, targets, self.loss_funcs))
  190. def recursive_sum(inputs):
  191. if isinstance(inputs, abc.Sequence):
  192. return sum([recursive_sum(x) for x in inputs])
  193. return inputs