iou3d_nms.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
3D IoU Calculation and Rotated NMS (modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#if defined(WITH_GPU)

#include <cuda.h>
#include <cuda_runtime_api.h>

#include <cstdio>  // fprintf / printf used by the error helpers below
#include <vector>  // std::vector used for the host-side NMS masks

#include "iou3d_nms.h"

namespace ultra_infer {
namespace paddle_custom_ops {
#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

// #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
static inline int DIVUP(const int m, const int n) {
  return ((m) / (n) + ((m) % (n) > 0));
}

#define CHECK_ERROR(ans)                                                       \
  { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
                      bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
            line);
    if (abort)
      exit(code);
  }
}

#define D(x)                                                                   \
  PD_THROW('\n', x,                                                            \
           "\n--------------------------------- where is the error ? "        \
           "---------------------------------------\n");

static const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
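
// THREADS_PER_BLOCK_NMS is 64 (the number of bits in an unsigned long long),
// so each CUDA block in the NMS kernels can pack the suppression flags for its
// 64 boxes into a single 64-bit mask word.

// The launchers below are implemented in the companion CUDA source of this
// custom op; only their host-side declarations are needed here.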
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
                          const int num_b, const float *boxes_b,
                          float *ans_overlap);
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
                         const float *boxes_b, float *ans_iou);
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
                 float nms_overlap_thresh);
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
                       int boxes_num, float nms_overlap_thresh);
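
// boxes_overlap_bev_gpu / boxes_iou_bev_gpu: thin wrappers that validate the
// inputs live on the GPU, extract raw float pointers, and launch the pairwise
// BEV overlap / IoU kernels. Results are written into the caller-allocated
// (N, M) output tensor; the int return value is a constant 1 and carries no
// status information.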
int boxes_overlap_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
                          paddle::Tensor ans_overlap) {
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  // params ans_overlap: (N, M)
  CHECK_INPUT(boxes_a);
  CHECK_INPUT(boxes_b);
  CHECK_INPUT(ans_overlap);

  int num_a = boxes_a.shape()[0];
  int num_b = boxes_b.shape()[0];

  const float *boxes_a_data = boxes_a.data<float>();
  const float *boxes_b_data = boxes_b.data<float>();
  float *ans_overlap_data = ans_overlap.data<float>();

  boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data,
                       ans_overlap_data);

  return 1;
}
int boxes_iou_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
                      paddle::Tensor ans_iou) {
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  // params ans_iou: (N, M)
  CHECK_INPUT(boxes_a);
  CHECK_INPUT(boxes_b);
  CHECK_INPUT(ans_iou);

  int num_a = boxes_a.shape()[0];
  int num_b = boxes_b.shape()[0];

  const float *boxes_a_data = boxes_a.data<float>();
  const float *boxes_b_data = boxes_b.data<float>();
  float *ans_iou_data = ans_iou.data<float>();

  boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);

  return 1;
}
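
// nms_gpu: greedy rotated NMS over (N, 7) boxes. The boxes are assumed to be
// pre-sorted by score in descending order (an assumption carried over from
// the upstream kernels). Returns a CPU INT32 `keep` tensor of candidate
// indices and a one-element CPU tensor holding how many leading entries of
// `keep` are valid.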
std::vector<paddle::Tensor> nms_gpu(const paddle::Tensor &boxes,
                                    float nms_overlap_thresh) {
  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params keep: (N)
  CHECK_INPUT(boxes);
  // CHECK_CONTIGUOUS(keep);
  auto keep = paddle::empty({boxes.shape()[0]}, paddle::DataType::INT32,
                            paddle::CPUPlace());
  auto num_to_keep_tensor =
      paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());
  int *num_to_keep_data = num_to_keep_tensor.data<int>();

  int boxes_num = boxes.shape()[0];
  const float *boxes_data = boxes.data<float>();
  int *keep_data = keep.data<int>();

  int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

  unsigned long long *mask_data = NULL;
  CHECK_ERROR(cudaMalloc((void **)&mask_data,
                         boxes_num * col_blocks * sizeof(unsigned long long)));
  nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
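
  // mask_data now holds boxes_num * col_blocks 64-bit words: in row i, bit j
  // of word b is set when box (b * 64 + j) overlaps box i above
  // nms_overlap_thresh (the convention used by the upstream NMS kernels).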
  // unsigned long long mask_cpu[boxes_num * col_blocks];
  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
  // col_blocks];
  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

  // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
                         boxes_num * col_blocks * sizeof(unsigned long long),
                         cudaMemcpyDeviceToHost));

  cudaFree(mask_data);

  // WARN(qiuyanjun): the code below throws a compile error on Windows with
  // MSVC (variable-length arrays are not supported), so a std::vector is used
  // to store the result instead.
  // unsigned long long remv_cpu[col_blocks];
  // memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
  std::vector<unsigned long long> remv_cpu(col_blocks, 0);

  int num_to_keep = 0;
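
  // Sequential greedy sweep on the host: walk the boxes in (assumed) score
  // order; a box is kept only if no previously kept box has already suppressed
  // it, and each kept box ORs its mask row into remv_cpu so that every box it
  // overlaps too strongly is skipped later.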
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / THREADS_PER_BLOCK_NMS;
    int inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
      keep_data[num_to_keep++] = i;
      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv_cpu[j] |= p[j];
      }
    }
  }

  num_to_keep_data[0] = num_to_keep;

  if (cudaSuccess != cudaGetLastError())
    printf("Error!\n");

  return {keep, num_to_keep_tensor};
}
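
// nms_normal_gpu: same greedy sweep as nms_gpu, but it launches the "normal"
// kernel variant (which, in the upstream implementation this is adapted from,
// measures BEV overlap without accounting for box heading) and writes the kept
// indices into the caller-provided `keep` tensor, returning the kept count.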
int nms_normal_gpu(paddle::Tensor boxes, paddle::Tensor keep,
                   float nms_overlap_thresh) {
  // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params keep: (N)
  CHECK_INPUT(boxes);
  // CHECK_CONTIGUOUS(keep);

  int boxes_num = boxes.shape()[0];
  const float *boxes_data = boxes.data<float>();
  // WARN(qiuyanjun): the long type for the Tensor::data() API is not exported
  // by paddle, which raises a link error on Windows with MSVC. Please check:
  // https://github.com/PaddlePaddle/Paddle/blob/release/2.5/paddle/phi/api/lib/tensor.cc
#if defined(_WIN32)
  int *keep_data = keep.data<int>();
#else
  long *keep_data = keep.data<long>();
#endif

  int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

  unsigned long long *mask_data = NULL;
  CHECK_ERROR(cudaMalloc((void **)&mask_data,
                         boxes_num * col_blocks * sizeof(unsigned long long)));
  nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);

  // unsigned long long mask_cpu[boxes_num * col_blocks];
  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
  // col_blocks];
  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

  // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
                         boxes_num * col_blocks * sizeof(unsigned long long),
                         cudaMemcpyDeviceToHost));

  cudaFree(mask_data);

  // WARN(qiuyanjun): the code below throws a compile error on Windows with
  // MSVC (variable-length arrays are not supported), so a std::vector is used
  // to store the result instead.
  // unsigned long long remv_cpu[col_blocks];
  // memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
  std::vector<unsigned long long> remv_cpu(col_blocks, 0);

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / THREADS_PER_BLOCK_NMS;
    int inblock = i % THREADS_PER_BLOCK_NMS;

    if (!(remv_cpu[nblock] & (1ULL << inblock))) {
      keep_data[num_to_keep++] = i;
      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv_cpu[j] |= p[j];
      }
    }
  }

  if (cudaSuccess != cudaGetLastError())
    printf("Error!\n");

  return num_to_keep;
}

} // namespace paddle_custom_ops
} // namespace ultra_infer

#endif