# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from scipy.optimize import linear_sum_assignment

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import accuracy

from paddlex.ppdet.core.workspace import register
from paddlex.ppdet.modeling.losses.iou_loss import GIoULoss

__all__ = ["SparseRCNNLoss"]


@register
class SparseRCNNLoss(nn.Layer):
    """ This class computes the loss for SparseRCNN.
    The process happens in two steps:
    1) we compute a Hungarian assignment between ground truth boxes and the outputs of the model
    2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 losses,
                 focal_loss_alpha,
                 focal_loss_gamma,
                 num_classes=80,
                 class_weight=2.,
                 l1_weight=5.,
                 giou_weight=2.):
  41. """ Create the criterion.
  42. Parameters:
  43. num_classes: number of object categories, omitting the special no-object category
  44. weight_dict: dict containing as key the names of the losses and as values their relative weight.
  45. losses: list of all the losses to be applied. See get_loss for list of available losses.
  46. matcher: module able to compute a matching between targets and proposals
  47. """
        super().__init__()
        self.num_classes = num_classes
        weight_dict = {
            "loss_ce": class_weight,
            "loss_bbox": l1_weight,
            "loss_giou": giou_weight
        }
        self.weight_dict = weight_dict
        self.losses = losses
        self.giou_loss = GIoULoss(reduction="sum")
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma,
                                        class_weight, l1_weight, giou_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (sigmoid focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = paddle.concat([
            paddle.gather(
                t["labels"], J, axis=0) for t, (_, J) in zip(targets, indices)
        ])
        target_classes = paddle.full(
            src_logits.shape[:2], self.num_classes, dtype="int32")
        for i, ind in enumerate(zip(idx[0], idx[1])):
            target_classes[int(ind[0]), int(ind[1])] = target_classes_o[i]
        target_classes.stop_gradient = True

        src_logits = src_logits.flatten(start_axis=0, stop_axis=1)

        # Prepare the one-hot targets. Unmatched proposals keep the label
        # num_classes, which lies outside class_ids and therefore yields an
        # all-zero row, i.e. pure background for the focal loss.
        target_classes = target_classes.flatten(start_axis=0, stop_axis=1)
        class_ids = paddle.arange(0, self.num_classes)
        labels = (target_classes.unsqueeze(-1) == class_ids).astype("float32")
        labels.stop_gradient = True

        # compute the sigmoid focal loss, normalized by the number of boxes
        class_loss = sigmoid_focal_loss(
            src_logits,
            labels,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum") / num_boxes
        losses = {'loss_ce': class_loss}

        if log:
            label_acc = target_classes_o.unsqueeze(-1)
            src_idx = [src for (src, _) in indices]

            pred_list = []
            for i in range(outputs["pred_logits"].shape[0]):
                pred_list.append(
                    paddle.gather(
                        outputs["pred_logits"][i], src_idx[i], axis=0))

            pred = F.sigmoid(paddle.concat(pred_list, axis=0))
            acc = accuracy(pred, label_acc.astype("int64"))
            losses["acc"] = acc

        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        Boxes are expected in absolute (x1, y1, x2, y2) format; the L1 loss is
        computed on boxes normalized by the image size.
        """
        assert 'pred_boxes' in outputs  # [batch_size, num_proposals, 4]
        src_idx = [src for (src, _) in indices]
        src_boxes_list = []

        for i in range(outputs["pred_boxes"].shape[0]):
            src_boxes_list.append(
                paddle.gather(
                    outputs["pred_boxes"][i], src_idx[i], axis=0))

        src_boxes = paddle.concat(src_boxes_list, axis=0)

        target_boxes = paddle.concat(
            [
                paddle.gather(
                    t['boxes'], I, axis=0)
                for t, (_, I) in zip(targets, indices)
            ],
            axis=0)
        target_boxes.stop_gradient = True
        losses = {}

        losses['loss_giou'] = self.giou_loss(src_boxes,
                                             target_boxes) / num_boxes
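
        # The L1 loss below uses boxes normalized by the per-target image size
        # (img_whwh_tgt); the GIoU loss above is scale-invariant and operates
        # on the absolute coordinates directly.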
        image_size = paddle.concat([v["img_whwh_tgt"] for v in targets])
        src_boxes_ = src_boxes / image_size
        target_boxes_ = target_boxes / image_size
        loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='sum')
        losses['loss_bbox'] = loss_bbox / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = paddle.concat(
            [paddle.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = paddle.concat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
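        # (unused in this file; kept for parity with the DETR-style criterion)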
        batch_idx = paddle.concat(
            [paddle.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = paddle.concat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {
            k: v
            for k, v in outputs.items() if k != 'aux_outputs'
        }

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the total number of target boxes across the batch, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = paddle.to_tensor(
            [num_boxes],
            dtype="float32",
            place=next(iter(outputs.values())).place)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_boxes))
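
        # NOTE: the top-level losses above are returned unweighted; only the
        # auxiliary losses below are scaled by self.weight_dict, so the caller
        # is expected to apply the weights to the top-level terms.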

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets,
                                           indices, num_boxes, **kwargs)

                    w_dict = {}
                    for k in l_dict.keys():
                        if k in self.weight_dict:
                            w_dict[k + f'_{i}'] = l_dict[k] * self.weight_dict[k]
                        else:
                            w_dict[k + f'_{i}'] = l_dict[k]
                    losses.update(w_dict)

        return losses


class HungarianMatcher(nn.Layer):
    """This class computes an assignment between the targets and the predictions of the network.
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 focal_loss_alpha,
                 focal_loss_gamma,
                 cost_class: float=1,
                 cost_bbox: float=1,
                 cost_giou: float=1):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    @paddle.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Args:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
                eg. outputs = {"pred_logits": pred_logits, "pred_boxes": pred_boxes}
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                    objects in the target) containing the class labels
                "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
                eg. targets = [{"labels": labels, "boxes": boxes}, ..., {"labels": labels, "boxes": boxes}]
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = F.sigmoid(outputs["pred_logits"].flatten(
            start_axis=0, stop_axis=1))
        out_bbox = outputs["pred_boxes"].flatten(start_axis=0, stop_axis=1)

        # Also concat the target labels and boxes
        tgt_ids = paddle.concat([v["labels"] for v in targets])
        assert (tgt_ids > -1).all()
        tgt_bbox = paddle.concat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we use a
        # focal-loss style cost rather than the NLL: for each prediction i
        # and target j,
        #   cost[i, j] = pos_cost[i, tgt_j] - neg_cost[i, tgt_j]
        # Constant terms that don't change the matching are omitted.
        alpha = self.focal_loss_alpha
        gamma = self.focal_loss_gamma
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(
            1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob)**gamma) * (
            -(out_prob + 1e-8).log())
        cost_class = paddle.gather(
            pos_cost_class, tgt_ids, axis=1) - paddle.gather(
                neg_cost_class, tgt_ids, axis=1)

        # Compute the L1 cost between boxes, normalized by the image size
        image_size_out = paddle.concat(
            [v["img_whwh"].unsqueeze(0) for v in targets])
        image_size_out = image_size_out.unsqueeze(1).tile(
            [1, num_queries, 1]).flatten(
                start_axis=0, stop_axis=1)
        image_size_tgt = paddle.concat([v["img_whwh_tgt"] for v in targets])

        out_bbox_ = out_bbox / image_size_out
        tgt_bbox_ = tgt_bbox / image_size_tgt
        cost_bbox = F.l1_loss(
            out_bbox_.unsqueeze(-2), tgt_bbox_,
            reduction='none').sum(-1)  # [batch_size * num_queries, num_tgts]

        # Compute the giou cost between boxes
        cost_giou = -get_bboxes_giou(out_bbox, tgt_bbox)

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class \
            + self.cost_giou * cost_giou
        C = C.reshape([bs, num_queries, -1])

        sizes = [len(v["boxes"]) for v in targets]
        indices = [
            linear_sum_assignment(c[i].numpy())
            for i, c in enumerate(C.split(sizes, -1))
        ]

        return [(paddle.to_tensor(
            i, dtype="int32"), paddle.to_tensor(
                j, dtype="int32")) for i, j in indices]


def box_area(boxes):
    assert (boxes[:, 2:] >= boxes[:, :2]).all()
    wh = boxes[:, 2:] - boxes[:, :2]
    return wh[:, 0] * wh[:, 1]


def boxes_iou(boxes1, boxes2):
    '''
    Compute the pairwise IoU between two sets of (x1, y1, x2, y2) boxes.
    Args:
        boxes1 (paddle.Tensor): shape (N, 4)
        boxes2 (paddle.Tensor): shape (M, 4)
    Return:
        iou (paddle.Tensor): shape (N, M)
        union (paddle.Tensor): shape (N, M)
    '''
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = paddle.maximum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
    rb = paddle.minimum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])

    wh = (rb - lt).astype("float32").clip(min=1e-9)
    inter = wh[:, :, 0] * wh[:, :, 1]

    union = area1.unsqueeze(-1) + area2 - inter + 1e-9
    iou = inter / union
    return iou, union


def get_bboxes_giou(boxes1, boxes2, eps=1e-9):
    """Calculate the pairwise GIoU of boxes1 and boxes2.
    Args:
        boxes1 (Tensor): shape [N, 4], in (x1, y1, x2, y2) format
        boxes2 (Tensor): shape [M, 4], in (x1, y1, x2, y2) format
        eps (float): epsilon to avoid divide by zero
    Return:
        giou (Tensor): GIoU of boxes1 and boxes2, with the shape [N, M]
    """
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = boxes_iou(boxes1, boxes2)
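
    # GIoU = IoU - (enclose_area - union) / enclose_area, where enclose_area
    # is the area of the smallest axis-aligned box enclosing both boxes.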
    lt = paddle.minimum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
    rb = paddle.maximum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])

    wh = (rb - lt).astype("float32").clip(min=eps)
    enclose_area = wh[:, :, 0] * wh[:, :, 1]

    giou = iou - (enclose_area - union) / enclose_area
    return giou


def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="sum"):
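    """Sigmoid focal loss (Lin et al., "Focal Loss for Dense Object Detection"):
        FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    computed per class from the logits `inputs` against one-hot `targets` and
    reduced with `reduction`; the alpha_t factor is skipped when alpha < 0.
    """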
    assert reduction in ["sum", "mean"], \
        f'unsupported reduction: {reduction}'

    p = F.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t)**gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss