iou3d_nms.cpp
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
3D IoU Calculation and Rotated NMS (modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/

#include "iou3d_nms.h"

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <paddle/extension.h>

#include <cstdio>
#include <vector>

#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

const int THREADS_PER_BLOCK_NMS = sizeof(int64_t) * 8;
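
// With 64 threads per NMS block (sizeof(int64_t) * 8), each int64_t mask
// word packs one suppression bit per box, so one row of the mask needs
// DIVUP(boxes_num, THREADS_PER_BLOCK_NMS) words.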

void BoxesOverlapLauncher(const cudaStream_t &stream, const int num_a,
                          const float *boxes_a, const int num_b,
                          const float *boxes_b, float *ans_overlap);
void BoxesIouBevLauncher(const cudaStream_t &stream, const int num_a,
                         const float *boxes_a, const int num_b,
                         const float *boxes_b, float *ans_iou);
void NmsLauncher(const cudaStream_t &stream, const float *boxes, int64_t *mask,
                 int boxes_num, float nms_overlap_thresh);
void NmsNormalLauncher(const cudaStream_t &stream, const float *boxes,
                       int64_t *mask, int boxes_num, float nms_overlap_thresh);
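
// The four launchers above wrap CUDA kernels, presumably implemented in the
// companion .cu file of this extension; each enqueues its kernel on the
// given stream.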

std::vector<paddle::Tensor>
boxes_overlap_bev_gpu(const paddle::Tensor &boxes_a,
                      const paddle::Tensor &boxes_b) {
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  // params ans_overlap: (N, M)
  int num_a = boxes_a.shape()[0];
  int num_b = boxes_b.shape()[0];

  const float *boxes_a_data = boxes_a.data<float>();
  const float *boxes_b_data = boxes_b.data<float>();

  auto ans_overlap = paddle::empty({num_a, num_b}, paddle::DataType::FLOAT32,
                                   paddle::GPUPlace());
  float *ans_overlap_data = ans_overlap.data<float>();

  BoxesOverlapLauncher(boxes_a.stream(), num_a, boxes_a_data, num_b,
                       boxes_b_data, ans_overlap_data);

  return {ans_overlap};
}

std::vector<paddle::Tensor>
boxes_iou_bev_gpu(const paddle::Tensor &boxes_a_tensor,
                  const paddle::Tensor &boxes_b_tensor) {
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  // params ans_iou: (N, M)
  int num_a = boxes_a_tensor.shape()[0];
  int num_b = boxes_b_tensor.shape()[0];

  const float *boxes_a_data = boxes_a_tensor.data<float>();
  const float *boxes_b_data = boxes_b_tensor.data<float>();

  auto ans_iou_tensor = paddle::empty({num_a, num_b}, paddle::DataType::FLOAT32,
                                      paddle::GPUPlace());
  float *ans_iou_data = ans_iou_tensor.data<float>();

  BoxesIouBevLauncher(boxes_a_tensor.stream(), num_a, boxes_a_data, num_b,
                      boxes_b_data, ans_iou_data);

  return {ans_iou_tensor};
}
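
// A minimal sketch of how boxes_iou_bev_gpu might be exposed as a Paddle
// custom op. This registration is hypothetical; the actual PD_BUILD_OP call
// for this extension, if any, lives elsewhere:
//
//   PD_BUILD_OP(boxes_iou_bev_gpu)
//       .Inputs({"boxes_a", "boxes_b"})
//       .Outputs({"ans_iou"})
//       .SetKernelFn(PD_KERNEL(boxes_iou_bev_gpu));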

std::vector<paddle::Tensor> nms_gpu(const paddle::Tensor &boxes,
                                    float nms_overlap_thresh) {
  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params keep: (N)
  auto keep = paddle::empty({boxes.shape()[0]}, paddle::DataType::INT32,
                            paddle::CPUPlace());
  auto num_to_keep_tensor =
      paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());
  int *num_to_keep_data = num_to_keep_tensor.data<int>();

  int boxes_num = boxes.shape()[0];
  const float *boxes_data = boxes.data<float>();
  int *keep_data = keep.data<int>();

  const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

  auto mask = paddle::empty({boxes_num * col_blocks}, paddle::DataType::INT64,
                            paddle::GPUPlace());
  int64_t *mask_data = mask.data<int64_t>();
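
  // Mask layout: bit k of mask[i * col_blocks + b] marks box (b * 64 + k)
  // as suppressed by box i, i.e. their overlap exceeds nms_overlap_thresh
  // as computed by the device kernel.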
  NmsLauncher(boxes.stream(), boxes_data, mask_data, boxes_num,
              nms_overlap_thresh);

  // Synchronously copy the suppression mask back to the host.
  const paddle::Tensor mask_cpu_tensor = mask.copy_to(paddle::CPUPlace(), true);
  const int64_t *mask_cpu = mask_cpu_tensor.data<int64_t>();
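
  // Greedy sequential pass on the CPU: walk boxes in index order (callers
  // typically pass boxes pre-sorted by descending score), keep a box if no
  // previously kept box suppresses it, then OR its mask row into the
  // running removal bitmap.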
  std::vector<int64_t> remv_cpu(col_blocks, 0);

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / THREADS_PER_BLOCK_NMS;
    int inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
      keep_data[num_to_keep++] = i;
      const int64_t *p = &mask_cpu[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv_cpu[j] |= p[j];
      }
    }
  }
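
  // Only the first num_to_keep entries of `keep` are meaningful; the rest of
  // the tensor is uninitialized, so callers should slice with the returned
  // count.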
  num_to_keep_data[0] = num_to_keep;

  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    printf("CUDA error in nms_gpu: %s\n", cudaGetErrorString(err));
  }

  return {keep, num_to_keep_tensor};
}

std::vector<paddle::Tensor> nms_normal_gpu(const paddle::Tensor &boxes,
                                           float nms_overlap_thresh) {
  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params keep: (N)
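  // This variant differs from nms_gpu only in the kernel it dispatches:
  // NmsNormalLauncher presumably measures overlap without rotation
  // (axis-aligned BEV IoU); see the kernel source for the exact criterion.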
  auto keep = paddle::empty({boxes.shape()[0]}, paddle::DataType::INT32,
                            paddle::CPUPlace());
  auto num_to_keep_tensor =
      paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());
  int *num_to_keep_data = num_to_keep_tensor.data<int>();

  int boxes_num = boxes.shape()[0];
  const float *boxes_data = boxes.data<float>();
  int *keep_data = keep.data<int>();

  const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

  auto mask = paddle::empty({boxes_num * col_blocks}, paddle::DataType::INT64,
                            paddle::GPUPlace());
  int64_t *mask_data = mask.data<int64_t>();

  NmsNormalLauncher(boxes.stream(), boxes_data, mask_data, boxes_num,
                    nms_overlap_thresh);

  // Synchronously copy the suppression mask back to the host.
  const paddle::Tensor mask_cpu_tensor = mask.copy_to(paddle::CPUPlace(), true);
  const int64_t *mask_cpu = mask_cpu_tensor.data<int64_t>();
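
  // Host-side greedy reduction, identical to the one in nms_gpu above.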
  std::vector<int64_t> remv_cpu(col_blocks, 0);

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / THREADS_PER_BLOCK_NMS;
    int inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
      keep_data[num_to_keep++] = i;
      const int64_t *p = &mask_cpu[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv_cpu[j] |= p[j];
      }
    }
  }
  num_to_keep_data[0] = num_to_keep;

  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    printf("CUDA error in nms_normal_gpu: %s\n", cudaGetErrorString(err));
  }

  return {keep, num_to_keep_tensor};
}