matchers.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from scipy.optimize import linear_sum_assignment
  21. from paddlex.ppdet.core.workspace import register, serializable
  22. from ..losses.iou_loss import GIoULoss
  23. from .utils import bbox_cxcywh_to_xyxy
  24. __all__ = ['HungarianMatcher']
  25. @register
  26. @serializable
  27. class HungarianMatcher(nn.Layer):
  28. __shared__ = ['use_focal_loss']
  29. def __init__(self,
  30. matcher_coeff={'class': 1,
  31. 'bbox': 5,
  32. 'giou': 2},
  33. use_focal_loss=False,
  34. alpha=0.25,
  35. gamma=2.0):
  36. r"""
  37. Args:
  38. matcher_coeff (dict): The coefficient of hungarian matcher cost.
  39. """
  40. super(HungarianMatcher, self).__init__()
  41. self.matcher_coeff = matcher_coeff
  42. self.use_focal_loss = use_focal_loss
  43. self.alpha = alpha
  44. self.gamma = gamma
  45. self.giou_loss = GIoULoss()
  46. def forward(self, boxes, logits, gt_bbox, gt_class):
  47. r"""
  48. Args:
  49. boxes (Tensor): [b, query, 4]
  50. logits (Tensor): [b, query, num_classes]
  51. gt_bbox (List(Tensor)): list[[n, 4]]
  52. gt_class (List(Tensor)): list[[n, 1]]
  53. Returns:
  54. A list of size batch_size, containing tuples of (index_i, index_j) where:
  55. - index_i is the indices of the selected predictions (in order)
  56. - index_j is the indices of the corresponding selected targets (in order)
  57. For each batch element, it holds:
  58. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  59. """
  60. bs, num_queries = boxes.shape[:2]
  61. num_gts = sum(len(a) for a in gt_class)
  62. if num_gts == 0:
  63. return [(paddle.to_tensor(
  64. [], dtype=paddle.int64), paddle.to_tensor(
  65. [], dtype=paddle.int64)) for _ in range(bs)]
  66. # We flatten to compute the cost matrices in a batch
  67. # [batch_size * num_queries, num_classes]
  68. out_prob = F.sigmoid(logits.flatten(
  69. 0, 1)) if self.use_focal_loss else F.softmax(
  70. logits.flatten(0, 1))
  71. # [batch_size * num_queries, 4]
  72. out_bbox = boxes.flatten(0, 1)
  73. # Also concat the target labels and boxes
  74. tgt_ids = paddle.concat(gt_class).flatten()
  75. tgt_bbox = paddle.concat(gt_bbox)
  76. # Compute the classification cost
  77. if self.use_focal_loss:
  78. neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(
  79. 1 - out_prob + 1e-8).log())
  80. pos_cost_class = self.alpha * (
  81. (1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
  82. cost_class = paddle.gather(
  83. pos_cost_class, tgt_ids, axis=1) - paddle.gather(
  84. neg_cost_class, tgt_ids, axis=1)
  85. else:
  86. cost_class = -paddle.gather(out_prob, tgt_ids, axis=1)
  87. # Compute the L1 cost between boxes
  88. cost_bbox = (
  89. out_bbox.unsqueeze(1) - tgt_bbox.unsqueeze(0)).abs().sum(-1)
  90. # Compute the giou cost betwen boxes
  91. cost_giou = self.giou_loss(
  92. bbox_cxcywh_to_xyxy(out_bbox.unsqueeze(1)),
  93. bbox_cxcywh_to_xyxy(tgt_bbox.unsqueeze(0))).squeeze(-1)
  94. # Final cost matrix
  95. C = self.matcher_coeff['class'] * cost_class + self.matcher_coeff['bbox'] * cost_bbox + \
  96. self.matcher_coeff['giou'] * cost_giou
  97. C = C.reshape([bs, num_queries, -1])
  98. C = [a.squeeze(0) for a in C.chunk(bs)]
  99. sizes = [a.shape[0] for a in gt_bbox]
  100. indices = [
  101. linear_sum_assignment(c.split(sizes, -1)[i].numpy())
  102. for i, c in enumerate(C)
  103. ]
  104. return [(paddle.to_tensor(
  105. i, dtype=paddle.int64), paddle.to_tensor(
  106. j, dtype=paddle.int64)) for i, j in indices]