post_process.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register
  19. from paddlex.ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
  20. from paddlex.ppdet.modeling.layers import TTFBox
  21. from .transformers import bbox_cxcywh_to_xyxy
  22. try:
  23. from collections.abc import Sequence
  24. except Exception:
  25. from collections import Sequence
  26. __all__ = [
  27. 'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
  28. 'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess',
  29. 'DETRBBoxPostProcess', 'SparsePostProcess'
  30. ]
  31. @register
  32. class BBoxPostProcess(nn.Layer):
  33. __shared__ = ['num_classes']
  34. __inject__ = ['decode', 'nms']
  35. def __init__(self, num_classes=80, decode=None, nms=None):
  36. super(BBoxPostProcess, self).__init__()
  37. self.num_classes = num_classes
  38. self.decode = decode
  39. self.nms = nms
  40. def forward(self, head_out, rois, im_shape, scale_factor):
  41. """
  42. Decode the bbox and do NMS if needed.
  43. Args:
  44. head_out (tuple): bbox_pred and cls_prob of bbox_head output.
  45. rois (tuple): roi and rois_num of rpn_head output.
  46. im_shape (Tensor): The shape of the input image.
  47. scale_factor (Tensor): The scale factor of the input image.
  48. Returns:
  49. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  50. labels, scores and bboxes. The size of bboxes are corresponding
  51. to the input image, the bboxes may be used in other branch.
  52. bbox_num (Tensor): The number of prediction boxes of each batch with
  53. shape [1], and is N.
  54. """
  55. if self.nms is not None:
  56. bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
  57. bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
  58. else:
  59. bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
  60. scale_factor)
  61. return bbox_pred, bbox_num
  62. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  63. """
  64. Rescale, clip and filter the bbox from the output of NMS to
  65. get final prediction.
  66. Notes:
  67. Currently only support bs = 1.
  68. Args:
  69. bboxes (Tensor): The output bboxes with shape [N, 6] after decode
  70. and NMS, including labels, scores and bboxes.
  71. bbox_num (Tensor): The number of prediction boxes of each batch with
  72. shape [1], and is N.
  73. im_shape (Tensor): The shape of the input image.
  74. scale_factor (Tensor): The scale factor of the input image.
  75. Returns:
  76. pred_result (Tensor): The final prediction results with shape [N, 6]
  77. including labels, scores and bboxes.
  78. """
  79. bboxes_list = []
  80. bbox_num_list = []
  81. id_start = 0
  82. fake_bboxes = paddle.to_tensor(
  83. np.array(
  84. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  85. fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  86. # add fake bbox when output is empty for each batch
  87. for i in range(bbox_num.shape[0]):
  88. if bbox_num[i] == 0:
  89. bboxes_i = fake_bboxes
  90. bbox_num_i = fake_bbox_num
  91. id_start += 1
  92. else:
  93. bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
  94. bbox_num_i = bbox_num[i]
  95. id_start += bbox_num[i]
  96. bboxes_list.append(bboxes_i)
  97. bbox_num_list.append(bbox_num_i)
  98. bboxes = paddle.concat(bboxes_list)
  99. bbox_num = paddle.concat(bbox_num_list)
  100. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  101. origin_shape_list = []
  102. scale_factor_list = []
  103. # scale_factor: scale_y, scale_x
  104. for i in range(bbox_num.shape[0]):
  105. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  106. [bbox_num[i], 2])
  107. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  108. scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
  109. expand_scale = paddle.expand(scale, [bbox_num[i], 4])
  110. origin_shape_list.append(expand_shape)
  111. scale_factor_list.append(expand_scale)
  112. self.origin_shape_list = paddle.concat(origin_shape_list)
  113. scale_factor_list = paddle.concat(scale_factor_list)
  114. # bboxes: [N, 6], label, score, bbox
  115. pred_label = bboxes[:, 0:1]
  116. pred_score = bboxes[:, 1:2]
  117. pred_bbox = bboxes[:, 2:]
  118. # rescale bbox to original image
  119. scaled_bbox = pred_bbox / scale_factor_list
  120. origin_h = self.origin_shape_list[:, 0]
  121. origin_w = self.origin_shape_list[:, 1]
  122. zeros = paddle.zeros_like(origin_h)
  123. # clip bbox to [0, original_size]
  124. x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
  125. y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
  126. x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
  127. y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
  128. pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  129. # filter empty bbox
  130. keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
  131. keep_mask = paddle.unsqueeze(keep_mask, [1])
  132. pred_label = paddle.where(keep_mask, pred_label,
  133. paddle.ones_like(pred_label) * -1)
  134. pred_result = paddle.concat(
  135. [pred_label, pred_score, pred_bbox], axis=1)
  136. return pred_result
  137. def get_origin_shape(self, ):
  138. return self.origin_shape_list
  139. @register
  140. class MaskPostProcess(object):
  141. """
  142. refer to:
  143. https://github.com/facebookresearch/detectron2/layers/mask_ops.py
  144. Get Mask output according to the output from model
  145. """
  146. def __init__(self, binary_thresh=0.5):
  147. super(MaskPostProcess, self).__init__()
  148. self.binary_thresh = binary_thresh
  149. def paste_mask(self, masks, boxes, im_h, im_w):
  150. """
  151. Paste the mask prediction to the original image.
  152. """
  153. x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
  154. masks = paddle.unsqueeze(masks, [0, 1])
  155. img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
  156. img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
  157. img_y = (img_y - y0) / (y1 - y0) * 2 - 1
  158. img_x = (img_x - x0) / (x1 - x0) * 2 - 1
  159. img_x = paddle.unsqueeze(img_x, [1])
  160. img_y = paddle.unsqueeze(img_y, [2])
  161. N = boxes.shape[0]
  162. gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
  163. gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
  164. grid = paddle.stack([gx, gy], axis=3)
  165. img_masks = F.grid_sample(masks, grid, align_corners=False)
  166. return img_masks[:, 0]
  167. def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
  168. """
  169. Decode the mask_out and paste the mask to the origin image.
  170. Args:
  171. mask_out (Tensor): mask_head output with shape [N, 28, 28].
  172. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  173. and NMS, including labels, scores and bboxes.
  174. bbox_num (Tensor): The number of prediction boxes of each batch with
  175. shape [1], and is N.
  176. origin_shape (Tensor): The origin shape of the input image, the tensor
  177. shape is [N, 2], and each row is [h, w].
  178. Returns:
  179. pred_result (Tensor): The final prediction mask results with shape
  180. [N, h, w] in binary mask style.
  181. """
  182. num_mask = mask_out.shape[0]
  183. origin_shape = paddle.cast(origin_shape, 'int32')
  184. # TODO: support bs > 1 and mask output dtype is bool
  185. pred_result = paddle.zeros(
  186. [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
  187. if bbox_num == 1 and bboxes[0][0] == -1:
  188. return pred_result
  189. # TODO: optimize chunk paste
  190. pred_result = []
  191. for i in range(bboxes.shape[0]):
  192. im_h, im_w = origin_shape[i][0], origin_shape[i][1]
  193. pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
  194. im_w)
  195. pred_mask = pred_mask >= self.binary_thresh
  196. pred_mask = paddle.cast(pred_mask, 'int32')
  197. pred_result.append(pred_mask)
  198. pred_result = paddle.concat(pred_result)
  199. return pred_result
  200. @register
  201. class FCOSPostProcess(object):
  202. __inject__ = ['decode', 'nms']
  203. def __init__(self, decode=None, nms=None):
  204. super(FCOSPostProcess, self).__init__()
  205. self.decode = decode
  206. self.nms = nms
  207. def __call__(self, fcos_head_outs, scale_factor):
  208. """
  209. Decode the bbox and do NMS in FCOS.
  210. """
  211. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  212. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  213. centerness, scale_factor)
  214. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  215. return bbox_pred, bbox_num
  216. @register
  217. class S2ANetBBoxPostProcess(nn.Layer):
  218. __shared__ = ['num_classes']
  219. __inject__ = ['nms']
  220. def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0,
  221. nms=None):
  222. super(S2ANetBBoxPostProcess, self).__init__()
  223. self.num_classes = num_classes
  224. self.nms_pre = paddle.to_tensor(nms_pre)
  225. self.min_bbox_size = min_bbox_size
  226. self.nms = nms
  227. self.origin_shape_list = []
  228. self.fake_pred_cls_score_bbox = paddle.to_tensor(
  229. np.array(
  230. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  231. dtype='float32'))
  232. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  233. def forward(self, pred_scores, pred_bboxes):
  234. """
  235. pred_scores : [N, M] score
  236. pred_bboxes : [N, 5] xc, yc, w, h, a
  237. im_shape : [N, 2] im_shape
  238. scale_factor : [N, 2] scale_factor
  239. """
  240. pred_ploys0 = rbox2poly(pred_bboxes)
  241. pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
  242. # pred_scores [NA, 16] --> [16, NA]
  243. pred_scores0 = paddle.transpose(pred_scores, [1, 0])
  244. pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
  245. pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
  246. self.num_classes)
  247. # Prevent empty bbox_pred from decode or NMS.
  248. # Bboxes and score before NMS may be empty due to the score threshold.
  249. if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
  250. 1] <= 1:
  251. pred_cls_score_bbox = self.fake_pred_cls_score_bbox
  252. bbox_num = self.fake_bbox_num
  253. pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
  254. return pred_cls_score_bbox, bbox_num
  255. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  256. """
  257. Rescale, clip and filter the bbox from the output of NMS to
  258. get final prediction.
  259. Args:
  260. bboxes(Tensor): bboxes [N, 10]
  261. bbox_num(Tensor): bbox_num
  262. im_shape(Tensor): [1 2]
  263. scale_factor(Tensor): [1 2]
  264. Returns:
  265. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  266. including labels, scores and bboxes. The size of
  267. bboxes are corresponding to the original image.
  268. """
  269. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  270. origin_shape_list = []
  271. scale_factor_list = []
  272. # scale_factor: scale_y, scale_x
  273. for i in range(bbox_num.shape[0]):
  274. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  275. [bbox_num[i], 2])
  276. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  277. scale = paddle.concat([
  278. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  279. scale_y
  280. ])
  281. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  282. origin_shape_list.append(expand_shape)
  283. scale_factor_list.append(expand_scale)
  284. origin_shape_list = paddle.concat(origin_shape_list)
  285. scale_factor_list = paddle.concat(scale_factor_list)
  286. # bboxes: [N, 10], label, score, bbox
  287. pred_label_score = bboxes[:, 0:2]
  288. pred_bbox = bboxes[:, 2:]
  289. # rescale bbox to original image
  290. pred_bbox = pred_bbox.reshape([-1, 8])
  291. scaled_bbox = pred_bbox / scale_factor_list
  292. origin_h = origin_shape_list[:, 0]
  293. origin_w = origin_shape_list[:, 1]
  294. bboxes = scaled_bbox
  295. zeros = paddle.zeros_like(origin_h)
  296. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  297. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  298. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  299. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  300. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  301. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  302. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  303. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  304. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  305. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  306. return pred_result
  307. @register
  308. class JDEBBoxPostProcess(nn.Layer):
  309. __shared__ = ['num_classes']
  310. __inject__ = ['decode', 'nms']
  311. def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
  312. super(JDEBBoxPostProcess, self).__init__()
  313. self.num_classes = num_classes
  314. self.decode = decode
  315. self.nms = nms
  316. self.return_idx = return_idx
  317. self.fake_bbox_pred = paddle.to_tensor(
  318. np.array(
  319. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  320. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  321. self.fake_nms_keep_idx = paddle.to_tensor(
  322. np.array(
  323. [[0]], dtype='int32'))
  324. self.fake_yolo_boxes_out = paddle.to_tensor(
  325. np.array(
  326. [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
  327. self.fake_yolo_scores_out = paddle.to_tensor(
  328. np.array(
  329. [[[0.0]]], dtype='float32'))
  330. self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
  331. def forward(self, head_out, anchors):
  332. """
  333. Decode the bbox and do NMS for JDE model.
  334. Args:
  335. head_out (list): Bbox_pred and cls_prob of bbox_head output.
  336. anchors (list): Anchors of JDE model.
  337. Returns:
  338. boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
  339. bbox_pred (Tensor): The output is the prediction with shape [N, 6]
  340. including labels, scores and bboxes.
  341. bbox_num (Tensor): The number of prediction of each batch with shape [N].
  342. nms_keep_idx (Tensor): The index of kept bboxes after NMS.
  343. """
  344. boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)
  345. if len(boxes_idx) == 0:
  346. boxes_idx = self.fake_boxes_idx
  347. yolo_boxes_out = self.fake_yolo_boxes_out
  348. yolo_scores_out = self.fake_yolo_scores_out
  349. else:
  350. yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
  351. # TODO: only support bs=1 now
  352. yolo_boxes_out = paddle.reshape(
  353. yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
  354. yolo_scores_out = paddle.reshape(
  355. yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
  356. boxes_idx = boxes_idx[:, 1:]
  357. if self.return_idx:
  358. bbox_pred, bbox_num, nms_keep_idx = self.nms(
  359. yolo_boxes_out, yolo_scores_out, self.num_classes)
  360. if bbox_pred.shape[0] == 0:
  361. bbox_pred = self.fake_bbox_pred
  362. bbox_num = self.fake_bbox_num
  363. nms_keep_idx = self.fake_nms_keep_idx
  364. return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
  365. else:
  366. bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
  367. self.num_classes)
  368. if bbox_pred.shape[0] == 0:
  369. bbox_pred = self.fake_bbox_pred
  370. bbox_num = self.fake_bbox_num
  371. return _, bbox_pred, bbox_num, _
  372. @register
  373. class CenterNetPostProcess(TTFBox):
  374. """
  375. Postprocess the model outputs to get final prediction:
  376. 1. Do NMS for heatmap to get top `max_per_img` bboxes.
  377. 2. Decode bboxes using center offset and box size.
  378. 3. Rescale decoded bboxes reference to the origin image shape.
  379. Args:
  380. max_per_img(int): the maximum number of predicted objects in a image,
  381. 500 by default.
  382. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  383. regress_ltrb (bool): whether to regress left/top/right/bottom or
  384. width/height for a box, true by default.
  385. for_mot (bool): whether return other features used in tracking model.
  386. """
  387. __shared__ = ['down_ratio', 'for_mot']
  388. def __init__(self,
  389. max_per_img=500,
  390. down_ratio=4,
  391. regress_ltrb=True,
  392. for_mot=False):
  393. super(TTFBox, self).__init__()
  394. self.max_per_img = max_per_img
  395. self.down_ratio = down_ratio
  396. self.regress_ltrb = regress_ltrb
  397. self.for_mot = for_mot
  398. def __call__(self, hm, wh, reg, im_shape, scale_factor):
  399. heat = self._simple_nms(hm)
  400. scores, inds, topk_clses, ys, xs = self._topk(heat)
  401. scores = scores.unsqueeze(1)
  402. clses = topk_clses.unsqueeze(1)
  403. reg_t = paddle.transpose(reg, [0, 2, 3, 1])
  404. # Like TTFBox, batch size is 1.
  405. # TODO: support batch size > 1
  406. reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
  407. reg = paddle.gather(reg, inds)
  408. xs = paddle.cast(xs, 'float32')
  409. ys = paddle.cast(ys, 'float32')
  410. xs = xs + reg[:, 0:1]
  411. ys = ys + reg[:, 1:2]
  412. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  413. wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
  414. wh = paddle.gather(wh, inds)
  415. if self.regress_ltrb:
  416. x1 = xs - wh[:, 0:1]
  417. y1 = ys - wh[:, 1:2]
  418. x2 = xs + wh[:, 2:3]
  419. y2 = ys + wh[:, 3:4]
  420. else:
  421. x1 = xs - wh[:, 0:1] / 2
  422. y1 = ys - wh[:, 1:2] / 2
  423. x2 = xs + wh[:, 0:1] / 2
  424. y2 = ys + wh[:, 1:2] / 2
  425. n, c, feat_h, feat_w = hm.shape[:]
  426. padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
  427. padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
  428. x1 = x1 * self.down_ratio
  429. y1 = y1 * self.down_ratio
  430. x2 = x2 * self.down_ratio
  431. y2 = y2 * self.down_ratio
  432. x1 = x1 - padw
  433. y1 = y1 - padh
  434. x2 = x2 - padw
  435. y2 = y2 - padh
  436. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  437. scale_y = scale_factor[:, 0:1]
  438. scale_x = scale_factor[:, 1:2]
  439. scale_expand = paddle.concat(
  440. [scale_x, scale_y, scale_x, scale_y], axis=1)
  441. boxes_shape = bboxes.shape[:]
  442. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  443. bboxes = paddle.divide(bboxes, scale_expand)
  444. if self.for_mot:
  445. results = paddle.concat([bboxes, scores, clses], axis=1)
  446. return results, inds, topk_clses
  447. else:
  448. results = paddle.concat([clses, scores, bboxes], axis=1)
  449. return results, paddle.shape(results)[0:1], topk_clses
  450. @register
  451. class DETRBBoxPostProcess(object):
  452. __shared__ = ['num_classes', 'use_focal_loss']
  453. __inject__ = []
  454. def __init__(self,
  455. num_classes=80,
  456. num_top_queries=100,
  457. use_focal_loss=False):
  458. super(DETRBBoxPostProcess, self).__init__()
  459. self.num_classes = num_classes
  460. self.num_top_queries = num_top_queries
  461. self.use_focal_loss = use_focal_loss
  462. def __call__(self, head_out, im_shape, scale_factor):
  463. """
  464. Decode the bbox.
  465. Args:
  466. head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
  467. im_shape (Tensor): The shape of the input image.
  468. scale_factor (Tensor): The scale factor of the input image.
  469. Returns:
  470. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  471. labels, scores and bboxes. The size of bboxes are corresponding
  472. to the input image, the bboxes may be used in other branch.
  473. bbox_num (Tensor): The number of prediction boxes of each batch with
  474. shape [bs], and is N.
  475. """
  476. bboxes, logits, masks = head_out
  477. bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
  478. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  479. img_h, img_w = origin_shape.unbind(1)
  480. origin_shape = paddle.stack(
  481. [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(0)
  482. bbox_pred *= origin_shape
  483. scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
  484. logits)[:, :, :-1]
  485. if not self.use_focal_loss:
  486. scores, labels = scores.max(-1), scores.argmax(-1)
  487. if scores.shape[1] > self.num_top_queries:
  488. scores, index = paddle.topk(
  489. scores, self.num_top_queries, axis=-1)
  490. labels = paddle.stack(
  491. [paddle.gather(l, i) for l, i in zip(labels, index)])
  492. bbox_pred = paddle.stack(
  493. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  494. else:
  495. scores, index = paddle.topk(
  496. scores.reshape([logits.shape[0], -1]),
  497. self.num_top_queries,
  498. axis=-1)
  499. labels = index % logits.shape[2]
  500. index = index // logits.shape[2]
  501. bbox_pred = paddle.stack(
  502. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  503. bbox_pred = paddle.concat(
  504. [
  505. labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
  506. bbox_pred
  507. ],
  508. axis=-1)
  509. bbox_num = paddle.to_tensor(
  510. bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
  511. bbox_pred = bbox_pred.reshape([-1, 6])
  512. return bbox_pred, bbox_num
  513. @register
  514. class SparsePostProcess(object):
  515. __shared__ = ['num_classes']
  516. def __init__(self, num_proposals, num_classes=80):
  517. super(SparsePostProcess, self).__init__()
  518. self.num_classes = num_classes
  519. self.num_proposals = num_proposals
  520. def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
  521. """
  522. Arguments:
  523. box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
  524. The tensor predicts the classification probability for each proposal.
  525. box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
  526. The tensor predicts 4-vector (x,y,w,h) box
  527. regression values for every proposal
  528. scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
  529. img_whwh (Tensor): tensors of shape [batch_size, 4]
  530. Returns:
  531. bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
  532. [label, confidence, xmin, ymin, xmax, ymax]
  533. bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
  534. """
  535. assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
  536. img_wh = img_whwh[:, :2]
  537. scores = F.sigmoid(box_cls)
  538. labels = paddle.arange(0, self.num_classes). \
  539. unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
  540. classes_all = []
  541. scores_all = []
  542. boxes_all = []
  543. for i, (scores_per_image,
  544. box_pred_per_image) in enumerate(zip(scores, box_pred)):
  545. scores_per_image, topk_indices = scores_per_image.flatten(
  546. 0, 1).topk(
  547. self.num_proposals, sorted=False)
  548. labels_per_image = paddle.gather(labels, topk_indices, axis=0)
  549. box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
  550. [1, self.num_classes, 1]).reshape([-1, 4])
  551. box_pred_per_image = paddle.gather(
  552. box_pred_per_image, topk_indices, axis=0)
  553. classes_all.append(labels_per_image)
  554. scores_all.append(scores_per_image)
  555. boxes_all.append(box_pred_per_image)
  556. bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
  557. boxes_final = []
  558. for i in range(len(scale_factor_wh)):
  559. classes = classes_all[i]
  560. boxes = boxes_all[i]
  561. scores = scores_all[i]
  562. boxes[:, 0::2] = paddle.clip(
  563. boxes[:, 0::2], min=0,
  564. max=img_wh[i][0]) / scale_factor_wh[i][0]
  565. boxes[:, 1::2] = paddle.clip(
  566. boxes[:, 1::2], min=0,
  567. max=img_wh[i][1]) / scale_factor_wh[i][1]
  568. boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
  569. boxes[:, 3] - boxes[:, 1]).numpy()
  570. keep = (boxes_w > 1.) & (boxes_h > 1.)
  571. if (keep.sum() == 0):
  572. bboxes = paddle.zeros([1, 6]).astype("float32")
  573. else:
  574. boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
  575. classes = paddle.to_tensor(classes.numpy()[keep]).astype(
  576. "float32").unsqueeze(-1)
  577. scores = paddle.to_tensor(scores.numpy()[keep]).astype(
  578. "float32").unsqueeze(-1)
  579. bboxes = paddle.concat([classes, scores, boxes], axis=-1)
  580. boxes_final.append(bboxes)
  581. bbox_num[i] = bboxes.shape[0]
  582. bbox_pred = paddle.concat(boxes_final)
  583. return bbox_pred, bbox_num
  584. def nms(dets, thresh):
  585. """Apply classic DPM-style greedy NMS."""
  586. if dets.shape[0] == 0:
  587. return dets[[], :]
  588. scores = dets[:, 0]
  589. x1 = dets[:, 1]
  590. y1 = dets[:, 2]
  591. x2 = dets[:, 3]
  592. y2 = dets[:, 4]
  593. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  594. order = scores.argsort()[::-1]
  595. ndets = dets.shape[0]
  596. suppressed = np.zeros((ndets), dtype=np.int)
  597. # nominal indices
  598. # _i, _j
  599. # sorted indices
  600. # i, j
  601. # temp variables for box i's (the box currently under consideration)
  602. # ix1, iy1, ix2, iy2, iarea
  603. # variables for computing overlap with box j (lower scoring box)
  604. # xx1, yy1, xx2, yy2
  605. # w, h
  606. # inter, ovr
  607. for _i in range(ndets):
  608. i = order[_i]
  609. if suppressed[i] == 1:
  610. continue
  611. ix1 = x1[i]
  612. iy1 = y1[i]
  613. ix2 = x2[i]
  614. iy2 = y2[i]
  615. iarea = areas[i]
  616. for _j in range(_i + 1, ndets):
  617. j = order[_j]
  618. if suppressed[j] == 1:
  619. continue
  620. xx1 = max(ix1, x1[j])
  621. yy1 = max(iy1, y1[j])
  622. xx2 = min(ix2, x2[j])
  623. yy2 = min(iy2, y2[j])
  624. w = max(0.0, xx2 - xx1 + 1)
  625. h = max(0.0, yy2 - yy1 + 1)
  626. inter = w * h
  627. ovr = inter / (iarea + areas[j] - inter)
  628. if ovr >= thresh:
  629. suppressed[j] = 1
  630. keep = np.where(suppressed == 0)[0]
  631. dets = dets[keep, :]
  632. return dets