post_process.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register
  19. from paddlex.ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
  20. from paddlex.ppdet.modeling.layers import TTFBox
  21. from .transformers import bbox_cxcywh_to_xyxy
  22. try:
  23. from collections.abc import Sequence
  24. except Exception:
  25. from collections import Sequence
  26. __all__ = [
  27. 'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
  28. 'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess',
  29. 'DETRBBoxPostProcess', 'SparsePostProcess'
  30. ]
@register
class BBoxPostProcess(nn.Layer):
    """
    Decode the bbox head output, optionally apply NMS, and (via get_pred)
    rescale and clip the resulting boxes to the original image.

    Args:
        num_classes (int): Number of classes; passed to ``nms`` as its third
            argument. Listed in ``__shared__`` (presumably filled from the
            global config by ``register`` -- confirm in ppdet.core.workspace).
        decode (callable): Box decoder, listed in ``__inject__``. When ``nms``
            is set it must return ``(bboxes, score)``; otherwise it must
            return ``(bbox_pred, bbox_num)`` directly.
        nms (callable|None): Optional NMS component returning
            ``(bbox_pred, bbox_num, _)``. When None, NMS is skipped.
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=80, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        # Placeholder used by get_pred when it receives zero boxes: a single
        # row [label=-1, score=0, x1=y1=x2=y2=0] with a box count of 1.
        self.fake_bboxes = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

    def forward(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the bbox and do NMS if needed.

        Args:
            head_out (tuple): bbox_pred and cls_prob of bbox_head output.
            rois (tuple): roi and rois_num of rpn_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
        Returns:
            bbox_pred (Tensor): The output prediction with shape [N, 6], including
                labels, scores and bboxes. The size of bboxes are corresponding
                to the input image, the bboxes may be used in other branch.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
        """
        if self.nms is not None:
            bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
            # The third NMS output is discarded here.
            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        else:
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)
        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Notes:
            Currently only support bs = 1.
            Side effect: stores the per-box original image shape in
            ``self.origin_shape_list`` (read back by ``get_origin_shape``).
            Boxes that become empty after clipping are NOT removed; their
            label is replaced with -1.

        Args:
            bboxes (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
        Returns:
            pred_result (Tensor): The final prediction results with shape [N, 6]
                including labels, scores and bboxes.
        """
        # No detections: substitute the single label -1 placeholder row so
        # the tensor ops below always see at least one box.
        if bboxes.shape[0] == 0:
            bboxes = self.fake_bboxes
            bbox_num = self.fake_bbox_num

        # Original (pre-resize) image size, rounded to the nearest integer.
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        origin_shape_list = []
        scale_factor_list = []
        # scale_factor: scale_y, scale_x
        # Repeat each image's shape / scale once per box of that image.
        for i in range(bbox_num.shape[0]):
            expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                         [bbox_num[i], 2])
            scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
            scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
            expand_scale = paddle.expand(scale, [bbox_num[i], 4])
            origin_shape_list.append(expand_shape)
            scale_factor_list.append(expand_scale)

        self.origin_shape_list = paddle.concat(origin_shape_list)
        scale_factor_list = paddle.concat(scale_factor_list)

        # bboxes: [N, 6], label, score, bbox
        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        pred_bbox = bboxes[:, 2:]
        # rescale bbox to original image
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        # clip bbox to [0, original_size]
        x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        # filter empty bbox: mark it with label -1 instead of dropping the row,
        # so the row count stays equal to bbox_num.
        keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
        keep_mask = paddle.unsqueeze(keep_mask, [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        pred_result = paddle.concat(
            [pred_label, pred_score, pred_bbox], axis=1)
        return pred_result

    def get_origin_shape(self, ):
        """
        Return the per-box original image shape cached by the last
        ``get_pred`` call (columns read there as [h, w]).

        ``origin_shape_list`` is only assigned inside ``get_pred``; call
        ``get_pred`` first.
        """
        return self.origin_shape_list
@register
class MaskPostProcess(object):
    """
    refer to:
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/mask_ops.py

    Get Mask output according to the output from model: each per-box mask
    prediction is pasted into the original image frame and binarized with
    ``binary_thresh``.

    Args:
        binary_thresh (float): pasted mask values ``>= binary_thresh`` become
            1, all others 0.
    """

    def __init__(self, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.binary_thresh = binary_thresh

    def paste_mask(self, masks, boxes, im_h, im_w):
        """
        Paste the mask prediction to the original image.

        Each output pixel centre (pixel index + 0.5) is mapped into the box's
        normalized [-1, 1] coordinate frame, and the mask is sampled at those
        locations with ``F.grid_sample`` (``align_corners=False``).

        Args:
            masks (Tensor): one mask prediction; unsqueezed on axes 0 and 1
                before sampling (assumed 2-D, e.g. [M, M] -- confirm).
            boxes (Tensor): boxes split along axis 1 into x0, y0, x1, y1,
                i.e. [N, 4]. ``__call__`` passes one box at a time (N == 1).
            im_h (int|Tensor): output height.
            im_w (int|Tensor): output width.
        Returns:
            Tensor: ``img_masks[:, 0]`` -- the sampled (not yet binarized)
            mask, presumably [N, im_h, im_w].
        """
        x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
        masks = paddle.unsqueeze(masks, [0, 1])
        img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
        img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
        # Normalize pixel centres to [-1, 1] relative to the box extent.
        img_y = (img_y - y0) / (y1 - y0) * 2 - 1
        img_x = (img_x - x0) / (x1 - x0) * 2 - 1
        img_x = paddle.unsqueeze(img_x, [1])
        img_y = paddle.unsqueeze(img_y, [2])
        N = boxes.shape[0]

        # Build the [N, H, W, 2] sampling grid expected by grid_sample.
        gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
        gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
        grid = paddle.stack([gx, gy], axis=3)
        img_masks = F.grid_sample(masks, grid, align_corners=False)
        return img_masks[:, 0]

    def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
        """
        Decode the mask_out and paste the mask to the origin image.

        Args:
            mask_out (Tensor): mask_head output with shape [N, 28, 28].
            bboxes (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.
            origin_shape (Tensor): The origin shape of the input image, the tensor
                shape is [N, 2] (one row per box), and each row is [h, w].
        Returns:
            pred_result (Tensor): The final prediction mask results with shape
                [N, h, w] in binary mask style (int32 0/1).
        """
        num_mask = mask_out.shape[0]
        origin_shape = paddle.cast(origin_shape, 'int32')
        # TODO: support bs > 1 and mask output dtype is bool
        pred_result = paddle.zeros(
            [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
        # A single row labelled -1 is the "no detection" placeholder; return
        # all-zero masks for it.
        if bbox_num == 1 and bboxes[0][0] == -1:
            return pred_result

        # TODO: optimize chunk paste
        pred_result = []
        for i in range(bboxes.shape[0]):
            im_h, im_w = origin_shape[i][0], origin_shape[i][1]
            # Columns 2: of each bbox row are the box coordinates.
            pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
                                        im_w)
            pred_mask = pred_mask >= self.binary_thresh
            pred_mask = paddle.cast(pred_mask, 'int32')
            pred_result.append(pred_mask)
        pred_result = paddle.concat(pred_result)
        return pred_result
  186. @register
  187. class FCOSPostProcess(object):
  188. __inject__ = ['decode', 'nms']
  189. def __init__(self, decode=None, nms=None):
  190. super(FCOSPostProcess, self).__init__()
  191. self.decode = decode
  192. self.nms = nms
  193. def __call__(self, fcos_head_outs, scale_factor):
  194. """
  195. Decode the bbox and do NMS in FCOS.
  196. """
  197. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  198. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  199. centerness, scale_factor)
  200. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  201. return bbox_pred, bbox_num
@register
class S2ANetBBoxPostProcess(nn.Layer):
    """
    Post-process rotated-box predictions: convert them to 4-point polygons,
    run NMS, and (via get_pred) rescale/clip the polygons to the original
    image.

    Args:
        num_classes (int): number of classes, passed to ``nms``.
        nms_pre (int): stored as a tensor; not read by the methods of this
            class.
        min_bbox_size (int|float): stored; not read by the methods of this
            class.
        nms (callable): injected NMS component returning
            ``(pred_cls_score_bbox, bbox_num, _)``.
    """
    __shared__ = ['num_classes']
    __inject__ = ['nms']

    def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0,
                 nms=None):
        super(S2ANetBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.nms_pre = paddle.to_tensor(nms_pre)
        self.min_bbox_size = min_bbox_size
        self.nms = nms
        # Not used by get_pred, which builds a local list of the same name.
        self.origin_shape_list = []
        # Placeholder row returned by forward when NMS yields nothing:
        # label -1, score 0 and eight zero polygon coordinates.
        self.fake_pred_cls_score_bbox = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

    def forward(self, pred_scores, pred_bboxes):
        """
        Convert rotated boxes to polygons and run NMS.

        Args:
            pred_scores (Tensor): [N, M] score
            pred_bboxes (Tensor): [N, 5] xc, yc, w, h, a
        Returns:
            pred_cls_score_bbox (Tensor): NMS output reshaped to [-1, 10]
                (label, score, 8 polygon coordinates). If NMS returns no rows,
                or at most one column, the single label -1 placeholder row is
                returned instead.
            bbox_num (Tensor): number of kept boxes (1 for the placeholder).
        """
        pred_ploys0 = rbox2poly(pred_bboxes)
        # Add a leading batch dimension for NMS.
        pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)

        # pred_scores [NA, 16] --> [16, NA]
        # i.e. transpose to class-major order, then add a batch dimension.
        pred_scores0 = paddle.transpose(pred_scores, [1, 0])
        pred_scores = paddle.unsqueeze(pred_scores0, axis=0)

        pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
                                                    self.num_classes)
        # Prevent empty bbox_pred from decode or NMS.
        # Bboxes and score before NMS may be empty due to the score threshold.
        if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
                1] <= 1:
            pred_cls_score_bbox = self.fake_pred_cls_score_bbox
            bbox_num = self.fake_bbox_num

        pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
        return pred_cls_score_bbox, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Args:
            bboxes(Tensor): bboxes [N, 10]
            bbox_num(Tensor): bbox_num
            im_shape(Tensor): [1 2]
            scale_factor(Tensor): [1 2]
        Returns:
            bbox_pred(Tensor): The output is the prediction with shape [N, 10]
                including labels, scores and 8 polygon coordinates. The size
                of bboxes are corresponding to the original image. Every
                coordinate is clipped to [0, origin_w - 1] / [0, origin_h - 1].
        """
        # Original (pre-resize) image size, rounded to the nearest integer.
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        origin_shape_list = []
        scale_factor_list = []
        # scale_factor: scale_y, scale_x
        # Repeat each image's shape / scale once per box of that image; the
        # scale is tiled x/y four times to match the 8 polygon coordinates.
        for i in range(bbox_num.shape[0]):
            expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                         [bbox_num[i], 2])
            scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
            scale = paddle.concat([
                scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
                scale_y
            ])
            expand_scale = paddle.expand(scale, [bbox_num[i], 8])
            origin_shape_list.append(expand_shape)
            scale_factor_list.append(expand_scale)

        origin_shape_list = paddle.concat(origin_shape_list)
        scale_factor_list = paddle.concat(scale_factor_list)

        # bboxes: [N, 10], label, score, bbox
        pred_label_score = bboxes[:, 0:2]
        pred_bbox = bboxes[:, 2:]

        # rescale bbox to original image
        pred_bbox = pred_bbox.reshape([-1, 8])
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = origin_shape_list[:, 0]
        origin_w = origin_shape_list[:, 1]

        bboxes = scaled_bbox
        zeros = paddle.zeros_like(origin_h)
        # Clip each polygon vertex into the image (upper bound is size - 1).
        x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
        y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
        x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
        y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
        x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
        y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
        x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
        y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
        pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
        return pred_result
@register
class JDEBBoxPostProcess(nn.Layer):
    """
    Decode JDE (YOLO-style) head outputs and run NMS, substituting fixed
    placeholder tensors whenever decode or NMS yields no boxes.

    Args:
        num_classes (int): number of classes, passed to ``nms``.
        decode (callable): injected decoder returning
            ``(boxes_idx, yolo_boxes_scores)``.
        nms (callable): injected NMS component returning
            ``(bbox_pred, bbox_num, nms_keep_idx)``.
        return_idx (bool): whether ``forward`` returns the real
            ``boxes_idx`` / ``nms_keep_idx`` (see ``forward``).
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
        super(JDEBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        self.return_idx = return_idx

        # Placeholders used when NMS keeps nothing: one label -1 row.
        self.fake_bbox_pred = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
        self.fake_nms_keep_idx = paddle.to_tensor(
            np.array(
                [[0]], dtype='int32'))

        # Placeholders used when decode returns no candidate indices: a single
        # all-zero box with zero score, in the [1, K, 4] / [1, 1, K] layout
        # built in forward.
        self.fake_yolo_boxes_out = paddle.to_tensor(
            np.array(
                [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
        self.fake_yolo_scores_out = paddle.to_tensor(
            np.array(
                [[[0.0]]], dtype='float32'))
        self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))

    def forward(self, head_out, anchors):
        """
        Decode the bbox and do NMS for JDE model.

        Args:
            head_out (list): Bbox_pred and cls_prob of bbox_head output.
            anchors (list): Anchors of JDE model.

        Returns:
            boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
            bbox_pred (Tensor): The output is the prediction with shape [N, 6]
                including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction of each batch with shape [N].
            nms_keep_idx (Tensor): The index of kept bboxes after NMS.

        Note:
            When ``return_idx`` is False, the first and fourth returned values
            are both the third output of ``nms`` (not ``boxes_idx``), and it
            is not replaced by a placeholder when NMS keeps nothing.
        """
        boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)

        if len(boxes_idx) == 0:
            boxes_idx = self.fake_boxes_idx
            yolo_boxes_out = self.fake_yolo_boxes_out
            yolo_scores_out = self.fake_yolo_scores_out
        else:
            yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
            # TODO: only support bs=1 now
            # Columns 0:4 are boxes, column 4 is the score.
            yolo_boxes_out = paddle.reshape(
                yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
            yolo_scores_out = paddle.reshape(
                yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
            # Drop the first index column (presumably the batch index --
            # confirm against the decoder).
            boxes_idx = boxes_idx[:, 1:]

        if self.return_idx:
            bbox_pred, bbox_num, nms_keep_idx = self.nms(
                yolo_boxes_out, yolo_scores_out, self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
                nms_keep_idx = self.fake_nms_keep_idx
            return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
        else:
            bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
                                              self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
            # `_` (third NMS output) fills both index slots in this mode.
            return _, bbox_pred, bbox_num, _
@register
class CenterNetPostProcess(TTFBox):
    """
    Postprocess the model outputs to get final prediction:
        1. Do NMS for heatmap to get top `max_per_img` bboxes.
        2. Decode bboxes using center offset and box size.
        3. Rescale decoded bboxes reference to the origin image shape.

    Args:
        max_per_img(int): the maximum number of predicted objects in a image,
            500 by default.
        down_ratio(int): the down ratio from images to heatmap, 4 by default.
        regress_ltrb (bool): whether to regress left/top/right/bottom or
            width/height for a box, true by default.
        for_mot (bool): whether return other features used in tracking model.

    ``_simple_nms`` and ``_topk`` are inherited from ``TTFBox`` (not shown
    here); presumably ``_topk`` reads ``self.max_per_img`` -- confirm.
    """

    __shared__ = ['down_ratio', 'for_mot']

    def __init__(self,
                 max_per_img=500,
                 down_ratio=4,
                 regress_ltrb=True,
                 for_mot=False):
        # NOTE: super(TTFBox, self) starts the MRO lookup *after* TTFBox, so
        # TTFBox.__init__ itself is skipped; every attribute this class needs
        # is set explicitly below.
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.down_ratio = down_ratio
        self.regress_ltrb = regress_ltrb
        self.for_mot = for_mot

    def __call__(self, hm, wh, reg, im_shape, scale_factor):
        """
        Decode CenterNet head outputs into boxes in original-image coordinates.

        Args:
            hm (Tensor): heatmap, unpacked as [n, c, feat_h, feat_w].
            wh (Tensor): box-size map; channels 0-3 are read as l/t/r/b when
                ``regress_ltrb`` is True, channels 0-1 as w/h otherwise.
            reg (Tensor): center-offset map; channels 0-1 are x/y offsets.
            im_shape (Tensor): network input shape; row 0 is read as [h, w].
            scale_factor (Tensor): columns are read as [scale_y, scale_x].

        Returns:
            If ``for_mot``: ``(results, inds, topk_clses)`` where each row of
            ``results`` is [x1, y1, x2, y2, score, class].
            Otherwise: ``(results, paddle.shape(results)[0:1], topk_clses)``
            where each row of ``results`` is [class, score, x1, y1, x2, y2].
        """
        heat = self._simple_nms(hm)
        scores, inds, topk_clses, ys, xs = self._topk(heat)
        scores = scores.unsqueeze(1)
        clses = topk_clses.unsqueeze(1)

        reg_t = paddle.transpose(reg, [0, 2, 3, 1])
        # Like TTFBox, batch size is 1.
        # TODO: support batch size > 1
        reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
        # Pick the regression values at the top-k peak locations.
        reg = paddle.gather(reg, inds)
        xs = paddle.cast(xs, 'float32')
        ys = paddle.cast(ys, 'float32')
        xs = xs + reg[:, 0:1]
        ys = ys + reg[:, 1:2]

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
        wh = paddle.gather(wh, inds)

        if self.regress_ltrb:
            x1 = xs - wh[:, 0:1]
            y1 = ys - wh[:, 1:2]
            x2 = xs + wh[:, 2:3]
            y2 = ys + wh[:, 3:4]
        else:
            x1 = xs - wh[:, 0:1] / 2
            y1 = ys - wh[:, 1:2] / 2
            x2 = xs + wh[:, 0:1] / 2
            y2 = ys + wh[:, 1:2] / 2

        n, c, feat_h, feat_w = hm.shape[:]
        # Half the difference between the up-sampled feature map and the
        # input image; subtracted below to undo a centered padding offset.
        padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
        padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
        # Feature-map coordinates -> network-input pixel coordinates.
        x1 = x1 * self.down_ratio
        y1 = y1 * self.down_ratio
        x2 = x2 * self.down_ratio
        y2 = y2 * self.down_ratio

        x1 = x1 - padw
        y1 = y1 - padh
        x2 = x2 - padw
        y2 = y2 - padh

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
        # Network-input pixels -> original-image pixels.
        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = bboxes.shape[:]
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        if self.for_mot:
            results = paddle.concat([bboxes, scores, clses], axis=1)
            return results, inds, topk_clses
        else:
            results = paddle.concat([clses, scores, bboxes], axis=1)
            return results, paddle.shape(results)[0:1], topk_clses
  436. @register
  437. class DETRBBoxPostProcess(object):
  438. __shared__ = ['num_classes', 'use_focal_loss']
  439. __inject__ = []
  440. def __init__(self,
  441. num_classes=80,
  442. num_top_queries=100,
  443. use_focal_loss=False):
  444. super(DETRBBoxPostProcess, self).__init__()
  445. self.num_classes = num_classes
  446. self.num_top_queries = num_top_queries
  447. self.use_focal_loss = use_focal_loss
  448. def __call__(self, head_out, im_shape, scale_factor):
  449. """
  450. Decode the bbox.
  451. Args:
  452. head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
  453. im_shape (Tensor): The shape of the input image.
  454. scale_factor (Tensor): The scale factor of the input image.
  455. Returns:
  456. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  457. labels, scores and bboxes. The size of bboxes are corresponding
  458. to the input image, the bboxes may be used in other branch.
  459. bbox_num (Tensor): The number of prediction boxes of each batch with
  460. shape [bs], and is N.
  461. """
  462. bboxes, logits, masks = head_out
  463. bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
  464. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  465. img_h, img_w = origin_shape.unbind(1)
  466. origin_shape = paddle.stack(
  467. [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(0)
  468. bbox_pred *= origin_shape
  469. scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
  470. logits)[:, :, :-1]
  471. if not self.use_focal_loss:
  472. scores, labels = scores.max(-1), scores.argmax(-1)
  473. if scores.shape[1] > self.num_top_queries:
  474. scores, index = paddle.topk(
  475. scores, self.num_top_queries, axis=-1)
  476. labels = paddle.stack(
  477. [paddle.gather(l, i) for l, i in zip(labels, index)])
  478. bbox_pred = paddle.stack(
  479. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  480. else:
  481. scores, index = paddle.topk(
  482. scores.reshape([logits.shape[0], -1]),
  483. self.num_top_queries,
  484. axis=-1)
  485. labels = index % logits.shape[2]
  486. index = index // logits.shape[2]
  487. bbox_pred = paddle.stack(
  488. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  489. bbox_pred = paddle.concat(
  490. [
  491. labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
  492. bbox_pred
  493. ],
  494. axis=-1)
  495. bbox_num = paddle.to_tensor(
  496. bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
  497. bbox_pred = bbox_pred.reshape([-1, 6])
  498. return bbox_pred, bbox_num
  499. @register
  500. class SparsePostProcess(object):
  501. __shared__ = ['num_classes']
  502. def __init__(self, num_proposals, num_classes=80):
  503. super(SparsePostProcess, self).__init__()
  504. self.num_classes = num_classes
  505. self.num_proposals = num_proposals
  506. def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
  507. """
  508. Arguments:
  509. box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
  510. The tensor predicts the classification probability for each proposal.
  511. box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
  512. The tensor predicts 4-vector (x,y,w,h) box
  513. regression values for every proposal
  514. scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
  515. img_whwh (Tensor): tensors of shape [batch_size, 4]
  516. Returns:
  517. bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
  518. [label, confidence, xmin, ymin, xmax, ymax]
  519. bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
  520. """
  521. assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
  522. img_wh = img_whwh[:, :2]
  523. scores = F.sigmoid(box_cls)
  524. labels = paddle.arange(0, self.num_classes). \
  525. unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
  526. classes_all = []
  527. scores_all = []
  528. boxes_all = []
  529. for i, (scores_per_image,
  530. box_pred_per_image) in enumerate(zip(scores, box_pred)):
  531. scores_per_image, topk_indices = scores_per_image.flatten(
  532. 0, 1).topk(
  533. self.num_proposals, sorted=False)
  534. labels_per_image = paddle.gather(labels, topk_indices, axis=0)
  535. box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
  536. [1, self.num_classes, 1]).reshape([-1, 4])
  537. box_pred_per_image = paddle.gather(
  538. box_pred_per_image, topk_indices, axis=0)
  539. classes_all.append(labels_per_image)
  540. scores_all.append(scores_per_image)
  541. boxes_all.append(box_pred_per_image)
  542. bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
  543. boxes_final = []
  544. for i in range(len(scale_factor_wh)):
  545. classes = classes_all[i]
  546. boxes = boxes_all[i]
  547. scores = scores_all[i]
  548. boxes[:, 0::2] = paddle.clip(
  549. boxes[:, 0::2], min=0,
  550. max=img_wh[i][0]) / scale_factor_wh[i][0]
  551. boxes[:, 1::2] = paddle.clip(
  552. boxes[:, 1::2], min=0,
  553. max=img_wh[i][1]) / scale_factor_wh[i][1]
  554. boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
  555. boxes[:, 3] - boxes[:, 1]).numpy()
  556. keep = (boxes_w > 1.) & (boxes_h > 1.)
  557. if (keep.sum() == 0):
  558. bboxes = paddle.zeros([1, 6]).astype("float32")
  559. else:
  560. boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
  561. classes = paddle.to_tensor(classes.numpy()[keep]).astype(
  562. "float32").unsqueeze(-1)
  563. scores = paddle.to_tensor(scores.numpy()[keep]).astype(
  564. "float32").unsqueeze(-1)
  565. bboxes = paddle.concat([classes, scores, boxes], axis=-1)
  566. boxes_final.append(bboxes)
  567. bbox_num[i] = bboxes.shape[0]
  568. bbox_pred = paddle.concat(boxes_final)
  569. return bbox_pred, bbox_num
  570. def nms(dets, thresh):
  571. """Apply classic DPM-style greedy NMS."""
  572. if dets.shape[0] == 0:
  573. return dets[[], :]
  574. scores = dets[:, 0]
  575. x1 = dets[:, 1]
  576. y1 = dets[:, 2]
  577. x2 = dets[:, 3]
  578. y2 = dets[:, 4]
  579. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  580. order = scores.argsort()[::-1]
  581. ndets = dets.shape[0]
  582. suppressed = np.zeros((ndets), dtype=np.int)
  583. # nominal indices
  584. # _i, _j
  585. # sorted indices
  586. # i, j
  587. # temp variables for box i's (the box currently under consideration)
  588. # ix1, iy1, ix2, iy2, iarea
  589. # variables for computing overlap with box j (lower scoring box)
  590. # xx1, yy1, xx2, yy2
  591. # w, h
  592. # inter, ovr
  593. for _i in range(ndets):
  594. i = order[_i]
  595. if suppressed[i] == 1:
  596. continue
  597. ix1 = x1[i]
  598. iy1 = y1[i]
  599. ix2 = x2[i]
  600. iy2 = y2[i]
  601. iarea = areas[i]
  602. for _j in range(_i + 1, ndets):
  603. j = order[_j]
  604. if suppressed[j] == 1:
  605. continue
  606. xx1 = max(ix1, x1[j])
  607. yy1 = max(iy1, y1[j])
  608. xx2 = min(ix2, x2[j])
  609. yy2 = min(iy2, y2[j])
  610. w = max(0.0, xx2 - xx1 + 1)
  611. h = max(0.0, yy2 - yy1 + 1)
  612. inter = w * h
  613. ovr = inter / (iarea + areas[j] - inter)
  614. if ovr >= thresh:
  615. suppressed[j] = 1
  616. keep = np.where(suppressed == 0)[0]
  617. dets = dets[keep, :]
  618. return dets