jde_embedding_head.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant

from paddlex.ppdet.core.workspace import register

__all__ = ['JDEEmbeddingHead']


class LossParam(nn.Layer):
    """Learnable scalar used to balance one task loss against the others."""

    def __init__(self, init_value=0., use_uncertainy=True):
        super(LossParam, self).__init__()
        self.loss_param = self.create_parameter(
            shape=[1],
            attr=ParamAttr(initializer=Constant(value=init_value)),
            dtype="float32")

    def forward(self, inputs):
        # Weight the incoming task loss by the learned parameter.
        out = paddle.exp(-self.loss_param) * inputs + self.loss_param
        return out * 0.5
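
# Note on loss balancing: each LossParam holds a learnable log-variance s and
# rescales its task loss L as 0.5 * (exp(-s) * L + s), i.e. uncertainty-based
# multi-task weighting. With the initial values used by JDEEmbeddingHead below
# (s = -4.15 for classification, -4.85 for box regression, -2.3 for the
# identity loss), the starting effective weights are roughly 31.7, 63.9 and
# 5.0 respectively; they are then learned jointly with the rest of the network.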


@register
class JDEEmbeddingHead(nn.Layer):
    """
    JDEEmbeddingHead

    Args:
        num_classes (int): Number of classes. Only single-class tracking is
            supported.
        num_identifiers (int): Number of identities; set from
            dataset.total_identities when training.
        anchor_levels (int): Number of anchor levels, the same as the number
            of FPN levels.
        anchor_scales (int): Number of anchor scales on each FPN level.
        embedding_dim (int): Embedding dimension. Default: 512.
        emb_loss (object): Instance of 'JDEEmbeddingLoss'.
        jde_loss (object): Instance of 'JDELoss'.

    A minimal usage sketch is given at the end of this file.
    """
    __shared__ = ['num_classes']
    __inject__ = ['emb_loss', 'jde_loss']

    def __init__(
            self,
            num_classes=1,
            num_identifiers=14455,  # set from dataset.total_identities when training
            anchor_levels=3,
            anchor_scales=4,
            embedding_dim=512,
            emb_loss='JDEEmbeddingLoss',
            jde_loss='JDELoss'):
        super(JDEEmbeddingHead, self).__init__()
        self.num_classes = num_classes
        self.num_identifiers = num_identifiers
        self.anchor_levels = anchor_levels
        self.anchor_scales = anchor_scales
        self.embedding_dim = embedding_dim
        self.emb_loss = emb_loss
        self.jde_loss = jde_loss

        # Scale applied to the L2-normalized embeddings before the identity
        # classifier: sqrt(2) * log(num_identifiers - 1).
        self.emb_scale = math.sqrt(2) * math.log(
            self.num_identifiers - 1) if self.num_identifiers > 1 else 1

        self.identify_outputs = []
        self.loss_params_cls = []
        self.loss_params_reg = []
        self.loss_params_ide = []
        for i in range(self.anchor_levels):
            name = 'identify_output.{}'.format(i)
            # 3x3 conv projecting each FPN level (512/256/128 input channels
            # by default) to the embedding dimension.
            identify_output = self.add_sublayer(
                name,
                nn.Conv2D(
                    in_channels=64 * (2**self.anchor_levels) // (2**i),
                    out_channels=self.embedding_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    weight_attr=ParamAttr(name=name + '.conv.weights'),
                    bias_attr=ParamAttr(
                        name=name + '.conv.bias', regularizer=L2Decay(0.))))
            self.identify_outputs.append(identify_output)

            # Per-level learnable loss weights for classification, box
            # regression and identity embedding.
            loss_p_cls = self.add_sublayer('cls.{}'.format(i), LossParam(-4.15))
            self.loss_params_cls.append(loss_p_cls)
            loss_p_reg = self.add_sublayer('reg.{}'.format(i), LossParam(-4.85))
            self.loss_params_reg.append(loss_p_reg)
            loss_p_ide = self.add_sublayer('ide.{}'.format(i), LossParam(-2.3))
            self.loss_params_ide.append(loss_p_ide)

        # Identity classifier over the embedding space; only used by the
        # embedding loss during training.
        self.classifier = self.add_sublayer(
            'classifier',
            nn.Linear(
                self.embedding_dim,
                self.num_identifiers,
                weight_attr=ParamAttr(
                    learning_rate=1., initializer=Normal(
                        mean=0.0, std=0.01)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.))))

    def forward(self,
                identify_feats,
                targets=None,
                loss_confs=None,
                loss_boxes=None,
                test_emb=False):
        """Project FPN features to embeddings.

        Returns the combined JDE multi-task loss in training, embeddings
        paired with ground-truth identities when test_emb is True, and
        per-location embeddings for tracking otherwise.
        """
        assert len(identify_feats) == self.anchor_levels
        ide_outs = []
        for feat, ide_head in zip(identify_feats, self.identify_outputs):
            ide_outs.append(ide_head(feat))

        if self.training:
            assert targets is not None
            assert len(loss_confs) == len(loss_boxes) == self.anchor_levels
            loss_ides = self.emb_loss(ide_outs, targets, self.emb_scale,
                                      self.classifier)
            return self.jde_loss(loss_confs, loss_boxes, loss_ides,
                                 self.loss_params_cls, self.loss_params_reg,
                                 self.loss_params_ide, targets)
        else:
            if test_emb:
                # Embedding evaluation: pair each embedding with its
                # ground-truth identity.
                assert targets is not None
                embs_and_gts = self.get_emb_and_gt_outs(ide_outs, targets)
                return embs_and_gts
            else:
                # Inference: return embeddings for data association.
                emb_outs = self.get_emb_outs(ide_outs)
                return emb_outs

    def get_emb_and_gt_outs(self, ide_outs, targets):
        # Collect (embedding, identity) pairs at positions that have a
        # positive anchor assignment, for embedding evaluation.
        emb_and_gts = []
        for i, p_ide in enumerate(ide_outs):
            t_conf = targets['tconf{}'.format(i)]
            t_ide = targets['tide{}'.format(i)]

            p_ide = p_ide.transpose((0, 2, 3, 1))
            p_ide_flatten = paddle.reshape(p_ide, [-1, self.embedding_dim])

            mask = t_conf > 0
            mask = paddle.cast(mask, dtype="int64")
            emb_mask = mask.max(1).flatten()
            emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten()
            if len(emb_mask_inds) > 0:
                t_ide_flatten = paddle.reshape(t_ide.max(1), [-1, 1])
                tids = paddle.gather(t_ide_flatten, emb_mask_inds)

                embedding = paddle.gather(p_ide_flatten, emb_mask_inds)
                embedding = self.emb_scale * F.normalize(embedding)
                emb_and_gt = paddle.concat([embedding, tids], axis=1)
                emb_and_gts.append(emb_and_gt)

        if len(emb_and_gts) > 0:
            return paddle.concat(emb_and_gts, axis=0)
        else:
            return paddle.zeros((1, self.embedding_dim + 1))

    def get_emb_outs(self, ide_outs):
        # Flatten the per-level embedding maps for inference; each map is
        # tiled once per anchor scale so it aligns with the detection outputs.
        emb_outs = []
        for i, p_ide in enumerate(ide_outs):
            p_ide = p_ide.transpose((0, 2, 3, 1))
            p_ide_repeat = paddle.tile(p_ide, [self.anchor_scales, 1, 1, 1])
            embedding = F.normalize(p_ide_repeat, axis=-1)
            emb = paddle.reshape(embedding, [-1, self.embedding_dim])
            emb_outs.append(emb)
        if len(emb_outs) > 0:
            return paddle.concat(emb_outs, axis=0)
        else:
            return paddle.zeros((1, self.embedding_dim))
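

# Minimal usage sketch (illustrative only): run the inference path with
# randomly generated feature maps. The spatial sizes below are arbitrary
# assumptions; only the channel counts (512/256/128) must match the per-level
# convolutions defined in __init__. The loss objects stay as their default
# string names because they are never called outside training.
if __name__ == '__main__':
    head = JDEEmbeddingHead()
    head.eval()
    feats = [
        paddle.rand([1, 512, 10, 18]),  # FPN level 0
        paddle.rand([1, 256, 20, 36]),  # FPN level 1
        paddle.rand([1, 128, 40, 72]),  # FPN level 2
    ]
    embs = head(feats)
    # One L2-normalized embedding per anchor scale and spatial location:
    # 4 * (10*18 + 20*36 + 40*72) = 15120 rows of dimension 512.
    print(embs.shape)  # [15120, 512]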