post_process.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddlex.ppdet.core.workspace import register
  19. from paddlex.ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly, rbox2poly
  20. from paddlex.ppdet.modeling.layers import TTFBox
  21. try:
  22. from collections.abc import Sequence
  23. except Exception:
  24. from collections import Sequence
  25. __all__ = [
  26. 'BBoxPostProcess',
  27. 'MaskPostProcess',
  28. 'FCOSPostProcess',
  29. 'S2ANetBBoxPostProcess',
  30. 'JDEBBoxPostProcess',
  31. 'CenterNetPostProcess',
  32. ]
  33. @register
  34. class BBoxPostProcess(object):
  35. __shared__ = ['num_classes']
  36. __inject__ = ['decode', 'nms']
  37. def __init__(self, num_classes=80, decode=None, nms=None):
  38. super(BBoxPostProcess, self).__init__()
  39. self.num_classes = num_classes
  40. self.decode = decode
  41. self.nms = nms
  42. def __call__(self, head_out, rois, im_shape, scale_factor):
  43. """
  44. Decode the bbox and do NMS if needed.
  45. Args:
  46. head_out (tuple): bbox_pred and cls_prob of bbox_head output.
  47. rois (tuple): roi and rois_num of rpn_head output.
  48. im_shape (Tensor): The shape of the input image.
  49. scale_factor (Tensor): The scale factor of the input image.
  50. Returns:
  51. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  52. labels, scores and bboxes. The size of bboxes are corresponding
  53. to the input image, the bboxes may be used in other branch.
  54. bbox_num (Tensor): The number of prediction boxes of each batch with
  55. shape [1], and is N.
  56. """
  57. if self.nms is not None:
  58. bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
  59. bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
  60. else:
  61. bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
  62. scale_factor)
  63. return bbox_pred, bbox_num
  64. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  65. """
  66. Rescale, clip and filter the bbox from the output of NMS to
  67. get final prediction.
  68. Notes:
  69. Currently only support bs = 1.
  70. Args:
  71. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  72. and NMS, including labels, scores and bboxes.
  73. bbox_num (Tensor): The number of prediction boxes of each batch with
  74. shape [1], and is N.
  75. im_shape (Tensor): The shape of the input image.
  76. scale_factor (Tensor): The scale factor of the input image.
  77. Returns:
  78. pred_result (Tensor): The final prediction results with shape [N, 6]
  79. including labels, scores and bboxes.
  80. """
  81. if bboxes.shape[0] == 0:
  82. bboxes = paddle.to_tensor(
  83. np.array(
  84. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  85. bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  86. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  87. origin_shape_list = []
  88. scale_factor_list = []
  89. # scale_factor: scale_y, scale_x
  90. for i in range(bbox_num.shape[0]):
  91. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  92. [bbox_num[i], 2])
  93. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  94. scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
  95. expand_scale = paddle.expand(scale, [bbox_num[i], 4])
  96. origin_shape_list.append(expand_shape)
  97. scale_factor_list.append(expand_scale)
  98. self.origin_shape_list = paddle.concat(origin_shape_list)
  99. scale_factor_list = paddle.concat(scale_factor_list)
  100. # bboxes: [N, 6], label, score, bbox
  101. pred_label = bboxes[:, 0:1]
  102. pred_score = bboxes[:, 1:2]
  103. pred_bbox = bboxes[:, 2:]
  104. # rescale bbox to original image
  105. scaled_bbox = pred_bbox / scale_factor_list
  106. origin_h = self.origin_shape_list[:, 0]
  107. origin_w = self.origin_shape_list[:, 1]
  108. zeros = paddle.zeros_like(origin_h)
  109. # clip bbox to [0, original_size]
  110. x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
  111. y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
  112. x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
  113. y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
  114. pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  115. # filter empty bbox
  116. keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
  117. keep_mask = paddle.unsqueeze(keep_mask, [1])
  118. pred_label = paddle.where(keep_mask, pred_label,
  119. paddle.ones_like(pred_label) * -1)
  120. pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
  121. return pred_result
  122. def get_origin_shape(self, ):
  123. return self.origin_shape_list
  124. @register
  125. class MaskPostProcess(object):
  126. def __init__(self, binary_thresh=0.5):
  127. super(MaskPostProcess, self).__init__()
  128. self.binary_thresh = binary_thresh
  129. def paste_mask(self, masks, boxes, im_h, im_w):
  130. """
  131. Paste the mask prediction to the original image.
  132. """
  133. x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
  134. masks = paddle.unsqueeze(masks, [0, 1])
  135. img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
  136. img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
  137. img_y = (img_y - y0) / (y1 - y0) * 2 - 1
  138. img_x = (img_x - x0) / (x1 - x0) * 2 - 1
  139. img_x = paddle.unsqueeze(img_x, [1])
  140. img_y = paddle.unsqueeze(img_y, [2])
  141. N = boxes.shape[0]
  142. gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
  143. gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
  144. grid = paddle.stack([gx, gy], axis=3)
  145. img_masks = F.grid_sample(masks, grid, align_corners=False)
  146. return img_masks[:, 0]
  147. def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
  148. """
  149. Decode the mask_out and paste the mask to the origin image.
  150. Args:
  151. mask_out (Tensor): mask_head output with shape [N, 28, 28].
  152. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  153. and NMS, including labels, scores and bboxes.
  154. bbox_num (Tensor): The number of prediction boxes of each batch with
  155. shape [1], and is N.
  156. origin_shape (Tensor): The origin shape of the input image, the tensor
  157. shape is [N, 2], and each row is [h, w].
  158. Returns:
  159. pred_result (Tensor): The final prediction mask results with shape
  160. [N, h, w] in binary mask style.
  161. """
  162. num_mask = mask_out.shape[0]
  163. origin_shape = paddle.cast(origin_shape, 'int32')
  164. # TODO: support bs > 1 and mask output dtype is bool
  165. pred_result = paddle.zeros(
  166. [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
  167. if bbox_num == 1 and bboxes[0][0] == -1:
  168. return pred_result
  169. # TODO: optimize chunk paste
  170. pred_result = []
  171. for i in range(bboxes.shape[0]):
  172. im_h, im_w = origin_shape[i][0], origin_shape[i][1]
  173. pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
  174. im_w)
  175. pred_mask = pred_mask >= self.binary_thresh
  176. pred_mask = paddle.cast(pred_mask, 'int32')
  177. pred_result.append(pred_mask)
  178. pred_result = paddle.concat(pred_result)
  179. return pred_result
  180. @register
  181. class FCOSPostProcess(object):
  182. __inject__ = ['decode', 'nms']
  183. def __init__(self, decode=None, nms=None):
  184. super(FCOSPostProcess, self).__init__()
  185. self.decode = decode
  186. self.nms = nms
  187. def __call__(self, fcos_head_outs, scale_factor):
  188. """
  189. Decode the bbox and do NMS in FCOS.
  190. """
  191. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  192. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  193. centerness, scale_factor)
  194. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  195. return bbox_pred, bbox_num
  196. @register
  197. class S2ANetBBoxPostProcess(nn.Layer):
  198. __shared__ = ['num_classes']
  199. __inject__ = ['nms']
  200. def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None):
  201. super(S2ANetBBoxPostProcess, self).__init__()
  202. self.num_classes = num_classes
  203. self.nms_pre = nms_pre
  204. self.min_bbox_size = min_bbox_size
  205. self.nms = nms
  206. self.origin_shape_list = []
  207. self.fake_pred_cls_score_bbox = paddle.to_tensor(
  208. np.array(
  209. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  210. dtype='float32'))
  211. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  212. def forward(self, pred_scores, pred_bboxes):
  213. """
  214. pred_scores : [N, M] score
  215. pred_bboxes : [N, 5] xc, yc, w, h, a
  216. im_shape : [N, 2] im_shape
  217. scale_factor : [N, 2] scale_factor
  218. """
  219. pred_ploys0 = rbox2poly(pred_bboxes)
  220. pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
  221. # pred_scores [NA, 16] --> [16, NA]
  222. pred_scores0 = paddle.transpose(pred_scores, [1, 0])
  223. pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
  224. pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
  225. self.num_classes)
  226. # Prevent empty bbox_pred from decode or NMS.
  227. # Bboxes and score before NMS may be empty due to the score threshold.
  228. if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
  229. 1] <= 1:
  230. pred_cls_score_bbox = self.fake_pred_cls_score_bbox
  231. bbox_num = self.fake_bbox_num
  232. pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
  233. return pred_cls_score_bbox, bbox_num
  234. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  235. """
  236. Rescale, clip and filter the bbox from the output of NMS to
  237. get final prediction.
  238. Args:
  239. bboxes(Tensor): bboxes [N, 10]
  240. bbox_num(Tensor): bbox_num
  241. im_shape(Tensor): [1 2]
  242. scale_factor(Tensor): [1 2]
  243. Returns:
  244. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  245. including labels, scores and bboxes. The size of
  246. bboxes are corresponding to the original image.
  247. """
  248. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  249. origin_shape_list = []
  250. scale_factor_list = []
  251. # scale_factor: scale_y, scale_x
  252. for i in range(bbox_num.shape[0]):
  253. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  254. [bbox_num[i], 2])
  255. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  256. scale = paddle.concat([
  257. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  258. scale_y
  259. ])
  260. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  261. origin_shape_list.append(expand_shape)
  262. scale_factor_list.append(expand_scale)
  263. origin_shape_list = paddle.concat(origin_shape_list)
  264. scale_factor_list = paddle.concat(scale_factor_list)
  265. # bboxes: [N, 10], label, score, bbox
  266. pred_label_score = bboxes[:, 0:2]
  267. pred_bbox = bboxes[:, 2:]
  268. # rescale bbox to original image
  269. pred_bbox = pred_bbox.reshape([-1, 8])
  270. scaled_bbox = pred_bbox / scale_factor_list
  271. origin_h = origin_shape_list[:, 0]
  272. origin_w = origin_shape_list[:, 1]
  273. bboxes = scaled_bbox
  274. zeros = paddle.zeros_like(origin_h)
  275. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  276. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  277. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  278. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  279. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  280. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  281. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  282. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  283. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  284. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  285. return pred_result
  286. @register
  287. class JDEBBoxPostProcess(BBoxPostProcess):
  288. def __call__(self, head_out, anchors):
  289. """
  290. Decode the bbox and do NMS for JDE model.
  291. Args:
  292. head_out (list): Bbox_pred and cls_prob of bbox_head output.
  293. anchors (list): Anchors of JDE model.
  294. Returns:
  295. boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
  296. bbox_pred (Tensor): The output is the prediction with shape [N, 6]
  297. including labels, scores and bboxes.
  298. bbox_num (Tensor): The number of prediction of each batch with shape [N].
  299. nms_keep_idx (Tensor): The index of kept bboxes after NMS.
  300. """
  301. boxes_idx, bboxes, score = self.decode(head_out, anchors)
  302. bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, score,
  303. self.num_classes)
  304. if bbox_pred.shape[0] == 0:
  305. bbox_pred = paddle.to_tensor(
  306. np.array(
  307. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  308. bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  309. nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32'))
  310. return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
  311. @register
  312. class CenterNetPostProcess(TTFBox):
  313. """
  314. Postprocess the model outputs to get final prediction:
  315. 1. Do NMS for heatmap to get top `max_per_img` bboxes.
  316. 2. Decode bboxes using center offset and box size.
  317. 3. Rescale decoded bboxes reference to the origin image shape.
  318. Args:
  319. max_per_img(int): the maximum number of predicted objects in a image,
  320. 500 by default.
  321. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  322. regress_ltrb (bool): whether to regress left/top/right/bottom or
  323. width/height for a box, true by default.
  324. for_mot (bool): whether return other features used in tracking model.
  325. """
  326. __shared__ = ['down_ratio']
  327. def __init__(self,
  328. max_per_img=500,
  329. down_ratio=4,
  330. regress_ltrb=True,
  331. for_mot=False):
  332. super(TTFBox, self).__init__()
  333. self.max_per_img = max_per_img
  334. self.down_ratio = down_ratio
  335. self.regress_ltrb = regress_ltrb
  336. self.for_mot = for_mot
  337. def __call__(self, hm, wh, reg, im_shape, scale_factor):
  338. heat = self._simple_nms(hm)
  339. scores, inds, clses, ys, xs = self._topk(heat)
  340. scores = paddle.tensor.unsqueeze(scores, [1])
  341. clses = paddle.tensor.unsqueeze(clses, [1])
  342. reg_t = paddle.transpose(reg, [0, 2, 3, 1])
  343. # Like TTFBox, batch size is 1.
  344. # TODO: support batch size > 1
  345. reg = paddle.reshape(reg_t, [-1, paddle.shape(reg_t)[-1]])
  346. reg = paddle.gather(reg, inds)
  347. xs = paddle.cast(xs, 'float32')
  348. ys = paddle.cast(ys, 'float32')
  349. xs = xs + reg[:, 0:1]
  350. ys = ys + reg[:, 1:2]
  351. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  352. wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
  353. wh = paddle.gather(wh, inds)
  354. if self.regress_ltrb:
  355. x1 = xs - wh[:, 0:1]
  356. y1 = ys - wh[:, 1:2]
  357. x2 = xs + wh[:, 2:3]
  358. y2 = ys + wh[:, 3:4]
  359. else:
  360. x1 = xs - wh[:, 0:1] / 2
  361. y1 = ys - wh[:, 1:2] / 2
  362. x2 = xs + wh[:, 0:1] / 2
  363. y2 = ys + wh[:, 1:2] / 2
  364. n, c, feat_h, feat_w = paddle.shape(hm)
  365. padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
  366. padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
  367. x1 = x1 * self.down_ratio
  368. y1 = y1 * self.down_ratio
  369. x2 = x2 * self.down_ratio
  370. y2 = y2 * self.down_ratio
  371. x1 = x1 - padw
  372. y1 = y1 - padh
  373. x2 = x2 - padw
  374. y2 = y2 - padh
  375. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  376. scale_y = scale_factor[:, 0:1]
  377. scale_x = scale_factor[:, 1:2]
  378. scale_expand = paddle.concat(
  379. [scale_x, scale_y, scale_x, scale_y], axis=1)
  380. boxes_shape = paddle.shape(bboxes)
  381. boxes_shape.stop_gradient = True
  382. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  383. bboxes = paddle.divide(bboxes, scale_expand)
  384. if self.for_mot:
  385. results = paddle.concat([bboxes, scores, clses], axis=1)
  386. return results, inds
  387. else:
  388. results = paddle.concat([clses, scores, bboxes], axis=1)
  389. return results, paddle.shape(results)[0:1]