iou3d_nms_kernel.cu 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. /*
  15. 3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
  16. Written by Shaoshuai Shi
  17. All Rights Reserved 2019-2020.
  18. */
  19. #include <stdio.h>
  // Edge length of the square (x, y) thread tile used by the pairwise
  // overlap / IoU kernels and their launchers.
  #define THREADS_PER_BLOCK 16
  // Integer ceiling division; intended for non-negative m and positive n.
  #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
  // Uncomment to enable the in-kernel printf tracing guarded by #ifdef DEBUG.
  // #define DEBUG
  // Number of bits in one int64_t mask word; the NMS kernels use one thread
  // (and one mask bit) per box in a column tile.
  const int THREADS_PER_BLOCK_NMS = static_cast<int>(sizeof(int64_t) * 8);
  // Small positive tolerance used as a denominator floor and as the
  // near-zero threshold in the segment intersection test.
  const float EPS = 1e-8f;
  // Minimal 2D point / vector with float coordinates for use in device code.
  struct Point {
    float x, y;

    // Intentionally leaves x and y uninitialized so arrays of Point can be
    // declared without per-element work; assign before reading.
    __device__ Point() {}

    // Takes float (not double) so that building a Point -- including the
    // temporaries returned by operator+ / operator- -- never promotes the
    // coordinates to double precision and back.
    __device__ Point(float _x, float _y) : x(_x), y(_y) {}

    // Overwrites both coordinates.
    __device__ void set(float _x, float _y) {
      x = _x;
      y = _y;
    }

    // Component-wise sum.
    __device__ Point operator+(const Point &b) const {
      return Point(x + b.x, y + b.y);
    }

    // Component-wise difference.
    __device__ Point operator-(const Point &b) const {
      return Point(x - b.x, y - b.y);
    }
  };
  // 2D cross product of vectors a and b (the z-component of a x b).
  __device__ inline float cross(const Point &a, const Point &b) {
    const float z_component = a.x * b.y - a.y * b.x;
    return z_component;
  }
  // 2D cross product of the vectors (p1 - p0) and (p2 - p0).
  __device__ inline float cross(const Point &p1, const Point &p2,
                                const Point &p0) {
    const float ux = p1.x - p0.x;
    const float uy = p1.y - p0.y;
    const float vx = p2.x - p0.x;
    const float vy = p2.y - p0.y;
    return ux * vy - vx * uy;
  }
  // Returns 1 when the axis-aligned bounding rectangle of segment (p1, p2)
  // overlaps (touching included) that of segment (q1, q2), otherwise 0.
  __device__ int check_rect_cross(const Point &p1, const Point &p2,
                                  const Point &q1, const Point &q2) {
    const bool overlap_x = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
                           min(q1.x, q2.x) <= max(p1.x, p2.x);
    const bool overlap_y = min(p1.y, p2.y) <= max(q1.y, q2.y) &&
                           min(q1.y, q2.y) <= max(p1.y, p2.y);
    return overlap_x && overlap_y;
  }
  // Returns non-zero when point p lies inside the BEV footprint of `box`,
  // enlarged by a small margin on every side.
  // params box: (7) [x, y, z, dx, dy, dz, heading]
  __device__ inline int check_in_box2d(const float *box, const Point &p) {
    const float kTolerance = 1e-2f;

    // Offset of the point from the box center.
    const float off_x = p.x - box[0];
    const float off_y = p.y - box[1];

    // Rotate the offset by -heading, i.e. opposite to the box rotation, so
    // the test below can be done against axis-aligned half extents.
    const float c = cos(-box[6]);
    const float s = sin(-box[6]);
    const float local_x = off_x * c + off_y * (-s);
    const float local_y = off_x * s + off_y * c;

    const bool within_dx = fabs(local_x) < box[3] / 2 + kTolerance;
    const bool within_dy = fabs(local_y) < box[4] / 2 + kTolerance;
    return within_dx && within_dy;
  }
  // Intersects segment (p0, p1) with segment (q0, q1).
  // Returns 1 and writes the crossing point to `ans` when the segments
  // strictly cross each other; returns 0 (leaving `ans` untouched) otherwise.
  __device__ inline int intersection(const Point &p1, const Point &p0,
                                     const Point &q1, const Point &q0,
                                     Point &ans) {
    // Cheap rejection: bounding rectangles of the two segments are disjoint.
    if (!check_rect_cross(p0, p1, q0, q1)) {
      return 0;
    }

    // Straddle test: the end points of each segment must lie strictly on
    // opposite sides of the other segment.
    const float s1 = cross(q0, p1, p0);
    const float s2 = cross(p1, q1, p0);
    const float s3 = cross(p0, q1, q0);
    const float s4 = cross(q1, p1, q0);
    const bool q_straddles_p = s1 * s2 > 0;
    const bool p_straddles_q = s3 * s4 > 0;
    if (!(q_straddles_p && p_straddles_q)) {
      return 0;
    }

    // Crossing point as a weighted combination of q0 and q1; fall back to
    // solving the two line equations when the weights' difference is tiny.
    const float s5 = cross(q1, p1, p0);
    const float denom = s5 - s1;
    if (fabs(denom) > EPS) {
      ans.x = (s5 * q0.x - s1 * q1.x) / denom;
      ans.y = (s5 * q0.y - s1 * q1.y) / denom;
    } else {
      // Lines in the form a * x + b * y + c = 0.
      const float a0 = p0.y - p1.y;
      const float b0 = p1.x - p0.x;
      const float c0 = p0.x * p1.y - p1.x * p0.y;
      const float a1 = q0.y - q1.y;
      const float b1 = q1.x - q0.x;
      const float c1 = q0.x * q1.y - q1.x * q0.y;
      const float D = a0 * b1 - a1 * b0;
      ans.x = (b0 * c1 - b1 * c0) / D;
      ans.y = (a1 * c0 - a0 * c1) / D;
    }
    return 1;
  }
  // Rotates point p in place around `center`, using the precomputed cosine
  // and sine of the rotation angle.
  __device__ inline void rotate_around_center(const Point &center,
                                              const float angle_cos,
                                              const float angle_sin, Point &p) {
    const float off_x = p.x - center.x;
    const float off_y = p.y - center.y;
    const float rotated_x = off_x * angle_cos + off_y * (-angle_sin) + center.x;
    const float rotated_y = off_x * angle_sin + off_y * angle_cos + center.y;
    p.x = rotated_x;
    p.y = rotated_y;
  }
  // Ordering predicate for sorting polygon vertices by polar angle: returns
  // non-zero when the angle of (a - center) is greater than that of
  // (b - center).
  __device__ inline int point_cmp(const Point &a, const Point &b,
                                  const Point &center) {
    const float angle_a = atan2(a.y - center.y, a.x - center.x);
    const float angle_b = atan2(b.y - center.y, b.x - center.x);
    return angle_a > angle_b;
  }
  // Computes the overlap area of two rotated boxes in the bird's-eye-view
  // (x, y) plane.
  //
  // params box_a: [x, y, z, dx, dy, dz, heading]
  // params box_b: [x, y, z, dx, dy, dz, heading]
  //
  // The intersection polygon is assembled from (1) the crossing points of
  // the boxes' edges and (2) the corners of each box that lie inside the
  // other box. Its vertices are then ordered by polar angle around their
  // centroid and the area is accumulated as a triangle fan.
  // Returns 0 when fewer than three polygon vertices are found.
  __device__ inline float box_overlap(const float *box_a, const float *box_b) {
    const float a_angle = box_a[6], b_angle = box_b[6];
    const float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
                a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
    const float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
    const float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
    const float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
    const float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;

    Point center_a(box_a[0], box_a[1]);
    Point center_b(box_b[0], box_b[1]);

  #ifdef DEBUG
    printf(
        "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
        a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
    printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
           center_b.x, center_b.y);
  #endif

    // Axis-aligned corners first; a fifth slot repeats corner 0 so that edge
    // k is always (corners[k], corners[k + 1]).
    Point box_a_corners[5];
    box_a_corners[0].set(a_x1, a_y1);
    box_a_corners[1].set(a_x2, a_y1);
    box_a_corners[2].set(a_x2, a_y2);
    box_a_corners[3].set(a_x1, a_y2);

    Point box_b_corners[5];
    box_b_corners[0].set(b_x1, b_y1);
    box_b_corners[1].set(b_x2, b_y1);
    box_b_corners[2].set(b_x2, b_y2);
    box_b_corners[3].set(b_x1, b_y2);

    // get oriented corners
    const float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
    const float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);

    for (int k = 0; k < 4; k++) {
  #ifdef DEBUG
      printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
             box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
             box_b_corners[k].y);
  #endif
      rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
      rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
  #ifdef DEBUG
      printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
             box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
  #endif
    }

    box_a_corners[4] = box_a_corners[0];
    box_b_corners[4] = box_b_corners[0];

    // get intersection of lines
    // NOTE(review): capacity 16 presumably relies on two convex quadrilaterals
    // yielding at most 8 edge crossings plus at most 8 contained corners.
    Point cross_points[16];
    Point poly_center;
    int cnt = 0, flag = 0;

    poly_center.set(0, 0);
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        flag = intersection(box_a_corners[i + 1], box_a_corners[i],
                            box_b_corners[j + 1], box_b_corners[j],
                            cross_points[cnt]);
        if (flag) {
          poly_center = poly_center + cross_points[cnt];
          cnt++;
  #ifdef DEBUG
          // Edge of b is indexed by j (the inner loop), not by i.
          printf(
              "Cross points (%.3f, %.3f): a(%.3f, %.3f)->(%.3f, %.3f), b(%.3f, "
              "%.3f)->(%.3f, %.3f) \n",
              cross_points[cnt - 1].x, cross_points[cnt - 1].y,
              box_a_corners[i].x, box_a_corners[i].y, box_a_corners[i + 1].x,
              box_a_corners[i + 1].y, box_b_corners[j].x, box_b_corners[j].y,
              box_b_corners[j + 1].x, box_b_corners[j + 1].y);
  #endif
        }
      }
    }

    // check corners
    for (int k = 0; k < 4; k++) {
      if (check_in_box2d(box_a, box_b_corners[k])) {
        poly_center = poly_center + box_b_corners[k];
        cross_points[cnt] = box_b_corners[k];
        cnt++;
  #ifdef DEBUG
        printf("b corners in a: corner_b(%.3f, %.3f)", cross_points[cnt - 1].x,
               cross_points[cnt - 1].y);
  #endif
      }
      if (check_in_box2d(box_b, box_a_corners[k])) {
        poly_center = poly_center + box_a_corners[k];
        cross_points[cnt] = box_a_corners[k];
        cnt++;
  #ifdef DEBUG
        printf("a corners in b: corner_a(%.3f, %.3f)", cross_points[cnt - 1].x,
               cross_points[cnt - 1].y);
  #endif
      }
    }

  #ifdef DEBUG
    printf("cnt=%d\n", cnt);
  #endif

    // Fewer than three vertices cannot enclose any area. Returning here also
    // avoids the 0 / 0 centroid division when the boxes do not touch at all.
    if (cnt < 3) {
      return 0.0f;
    }

    poly_center.x /= cnt;
    poly_center.y /= cnt;

    // sort the points of polygon (bubble sort by polar angle around centroid)
    Point temp;
    for (int j = 0; j < cnt - 1; j++) {
      for (int i = 0; i < cnt - j - 1; i++) {
        if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
          temp = cross_points[i];
          cross_points[i] = cross_points[i + 1];
          cross_points[i + 1] = temp;
        }
      }
    }

  #ifdef DEBUG
    for (int i = 0; i < cnt; i++) {
      printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
             cross_points[i].y);
    }
  #endif

    // get the overlap areas: sum of signed triangle-fan areas anchored at
    // vertex 0 (each cross product is twice a triangle's signed area).
    float area = 0;
    for (int k = 0; k < cnt - 1; k++) {
      area += cross(cross_points[k] - cross_points[0],
                    cross_points[k + 1] - cross_points[0]);
    }

    return fabsf(area) / 2.0f;
  }
  // Bird's-eye-view IoU of two rotated boxes: overlap area divided by the
  // union area, with the union floored at EPS to avoid dividing by zero.
  // params box_a: [x, y, z, dx, dy, dz, heading]
  // params box_b: [x, y, z, dx, dy, dz, heading]
  __device__ inline float iou_bev(const float *box_a, const float *box_b) {
    const float area_a = box_a[3] * box_a[4];
    const float area_b = box_b[3] * box_b[4];
    const float inter = box_overlap(box_a, box_b);
    const float union_area = area_a + area_b - inter;
    return inter / fmaxf(union_area, EPS);
  }
  // Writes the pairwise BEV overlap area of every box in boxes_a with every
  // box in boxes_b into the row-major (num_a, num_b) matrix ans_overlap.
  //
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  //
  // Indexing assumes a 2D launch with THREADS_PER_BLOCK x THREADS_PER_BLOCK
  // threads per block: x walks boxes_b (columns), y walks boxes_a (rows).
  __global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a,
                                       const int num_b, const float *boxes_b,
                                       float *ans_overlap) {
    const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
    const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;

    // Threads in the ragged tail of the grid have nothing to do.
    if (a_idx >= num_a || b_idx >= num_b) {
      return;
    }

    const float *cur_box_a = boxes_a + static_cast<size_t>(a_idx) * 7;
    const float *cur_box_b = boxes_b + static_cast<size_t>(b_idx) * 7;

    // Flat output index in size_t: a_idx * num_b overflows a 32-bit int once
    // the result matrix holds more than 2^31 elements.
    const size_t out_idx = static_cast<size_t>(a_idx) * num_b + b_idx;
    ans_overlap[out_idx] = box_overlap(cur_box_a, cur_box_b);
  }
  // Writes the pairwise BEV IoU of every box in boxes_a with every box in
  // boxes_b into the row-major (num_a, num_b) matrix ans_iou.
  //
  // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
  // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
  //
  // Indexing assumes a 2D launch with THREADS_PER_BLOCK x THREADS_PER_BLOCK
  // threads per block: x walks boxes_b (columns), y walks boxes_a (rows).
  __global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a,
                                       const int num_b, const float *boxes_b,
                                       float *ans_iou) {
    const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
    const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;

    // Threads in the ragged tail of the grid have nothing to do.
    if (a_idx >= num_a || b_idx >= num_b) {
      return;
    }

    const float *cur_box_a = boxes_a + static_cast<size_t>(a_idx) * 7;
    const float *cur_box_b = boxes_b + static_cast<size_t>(b_idx) * 7;

    // Flat output index in size_t: a_idx * num_b overflows a 32-bit int once
    // the result matrix holds more than 2^31 elements.
    const size_t out_idx = static_cast<size_t>(a_idx) * num_b + b_idx;
    ans_iou[out_idx] = iou_bev(cur_box_a, cur_box_b);
  }
  // Builds the pairwise suppression bit mask for rotated-box NMS.
  //
  // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
  //
  // Block (blockIdx.y, blockIdx.x) compares a tile of up to
  // THREADS_PER_BLOCK_NMS "row" boxes against a tile of up to
  // THREADS_PER_BLOCK_NMS "column" boxes. Bit i of
  // mask[row_box * col_blocks + col_tile] is set when the BEV IoU of the row
  // box with column box (col_tile * THREADS_PER_BLOCK_NMS + i) exceeds
  // nms_overlap_thresh. On the diagonal tile only column boxes with a larger
  // index than the row box are tested.
  // Expects a 1D block of THREADS_PER_BLOCK_NMS threads (one per tile box).
  __global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
                             const float *boxes, int64_t *mask) {
    const int row_start = blockIdx.y;
    const int col_start = blockIdx.x;

    // if (row_start > col_start) return;

    // Integer min: number of valid boxes in this row / column tile.
    const int row_size = min(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
                             THREADS_PER_BLOCK_NMS);
    const int col_size = min(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
                             THREADS_PER_BLOCK_NMS);

    // Stage the column tile in shared memory; each thread copies one box.
    __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];

    if (threadIdx.x < col_size) {
      const float *src =
          boxes +
          static_cast<size_t>(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7;
      float *dst = block_boxes + threadIdx.x * 7;
      for (int k = 0; k < 7; k++) {
        dst[k] = src[k];
      }
    }
    // Barrier is outside the branch so every thread of the block reaches it.
    __syncthreads();

    if (threadIdx.x < row_size) {
      const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
      const float *cur_box = boxes + static_cast<size_t>(cur_box_idx) * 7;

      // On the diagonal tile skip the box itself and all earlier boxes.
      int start = 0;
      if (row_start == col_start) {
        start = threadIdx.x + 1;
      }

      unsigned long long t = 0;
      for (int i = start; i < col_size; i++) {
        if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
          t |= 1ULL << i;
        }
      }

      const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
      // size_t index: cur_box_idx * col_blocks grows like N * N / 64 and
      // overflows a 32-bit int for large N.
      mask[static_cast<size_t>(cur_box_idx) * col_blocks + col_start] =
          static_cast<int64_t>(t);
    }
  }
  // Axis-aligned BEV IoU of two boxes; only x, y, dx and dy are read, the
  // heading is not used.
  // params: a: [x, y, z, dx, dy, dz, heading]
  // params: b: [x, y, z, dx, dy, dz, heading]
  __device__ inline float iou_normal(float const *const a, float const *const b) {
    // Extent of the intersection rectangle along x and y.
    const float x_lo = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2);
    const float x_hi = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
    const float y_lo = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2);
    const float y_hi = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);

    // Clamp to zero when the boxes are disjoint along an axis.
    const float overlap_w = fmaxf(x_hi - x_lo, 0.f);
    const float overlap_h = fmaxf(y_hi - y_lo, 0.f);
    const float inter = overlap_w * overlap_h;

    const float area_a = a[3] * a[4];
    const float area_b = b[3] * b[4];
    return inter / fmaxf(area_a + area_b - inter, EPS);
  }
  // Builds the pairwise suppression bit mask for axis-aligned NMS.
  //
  // params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
  // params: mask (N, N/THREADS_PER_BLOCK_NMS)
  //
  // Block (blockIdx.y, blockIdx.x) compares a tile of up to
  // THREADS_PER_BLOCK_NMS "row" boxes against a tile of up to
  // THREADS_PER_BLOCK_NMS "column" boxes. Bit i of
  // mask[row_box * col_blocks + col_tile] is set when iou_normal of the row
  // box with column box (col_tile * THREADS_PER_BLOCK_NMS + i) exceeds
  // nms_overlap_thresh. On the diagonal tile only column boxes with a larger
  // index than the row box are tested.
  // Expects a 1D block of THREADS_PER_BLOCK_NMS threads (one per tile box).
  __global__ void nms_normal_kernel(const int boxes_num,
                                    const float nms_overlap_thresh,
                                    const float *boxes, int64_t *mask) {
    const int row_start = blockIdx.y;
    const int col_start = blockIdx.x;

    // if (row_start > col_start) return;

    // Integer min: number of valid boxes in this row / column tile.
    const int row_size = min(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
                             THREADS_PER_BLOCK_NMS);
    const int col_size = min(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
                             THREADS_PER_BLOCK_NMS);

    // Stage the column tile in shared memory; each thread copies one box.
    __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];

    if (threadIdx.x < col_size) {
      const float *src =
          boxes +
          static_cast<size_t>(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7;
      float *dst = block_boxes + threadIdx.x * 7;
      for (int k = 0; k < 7; k++) {
        dst[k] = src[k];
      }
    }
    // Barrier is outside the branch so every thread of the block reaches it.
    __syncthreads();

    if (threadIdx.x < row_size) {
      const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
      const float *cur_box = boxes + static_cast<size_t>(cur_box_idx) * 7;

      // On the diagonal tile skip the box itself and all earlier boxes.
      int start = 0;
      if (row_start == col_start) {
        start = threadIdx.x + 1;
      }

      unsigned long long t = 0;
      for (int i = start; i < col_size; i++) {
        if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
          t |= 1ULL << i;
        }
      }

      const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
      // size_t index: cur_box_idx * col_blocks grows like N * N / 64 and
      // overflows a 32-bit int for large N.
      mask[static_cast<size_t>(cur_box_idx) * col_blocks + col_start] =
          static_cast<int64_t>(t);
    }
  }
  // Enqueues boxes_overlap_kernel on `stream` to fill the (num_a, num_b)
  // overlap matrix ans_overlap. The pointers are passed straight to the
  // kernel, so they must be device-accessible. Does nothing when either box
  // set is empty.
  void BoxesOverlapLauncher(const cudaStream_t &stream, const int num_a,
                            const float *boxes_a, const int num_b,
                            const float *boxes_b, float *ans_overlap) {
    // A grid with a zero dimension is an invalid launch configuration, and
    // there is no output to produce anyway.
    if (num_a <= 0 || num_b <= 0) {
      return;
    }

    dim3 blocks(
        DIVUP(num_b, THREADS_PER_BLOCK),
        DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

    boxes_overlap_kernel<<<blocks, threads, 0, stream>>>(num_a, boxes_a, num_b,
                                                         boxes_b, ans_overlap);
  #ifdef DEBUG
    cudaDeviceSynchronize();  // for using printf in kernel function
  #endif
  }
  // Enqueues boxes_iou_bev_kernel on `stream` to fill the (num_a, num_b)
  // BEV IoU matrix ans_iou. The pointers are passed straight to the kernel,
  // so they must be device-accessible. Does nothing when either box set is
  // empty.
  void BoxesIouBevLauncher(const cudaStream_t &stream, const int num_a,
                           const float *boxes_a, const int num_b,
                           const float *boxes_b, float *ans_iou) {
    // A grid with a zero dimension is an invalid launch configuration, and
    // there is no output to produce anyway.
    if (num_a <= 0 || num_b <= 0) {
      return;
    }

    dim3 blocks(
        DIVUP(num_b, THREADS_PER_BLOCK),
        DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

    boxes_iou_bev_kernel<<<blocks, threads, 0, stream>>>(num_a, boxes_a, num_b,
                                                         boxes_b, ans_iou);
  #ifdef DEBUG
    cudaDeviceSynchronize();  // for using printf in kernel function
  #endif
  }
  // Enqueues nms_kernel on `stream` to compute the rotated-box suppression
  // mask for `boxes_num` boxes. The grid has one block per
  // (row tile, column tile) pair and THREADS_PER_BLOCK_NMS threads per block.
  // Does nothing when there are no boxes.
  void NmsLauncher(const cudaStream_t &stream, const float *boxes, int64_t *mask,
                   int boxes_num, float nms_overlap_thresh) {
    // A grid with a zero dimension is an invalid launch configuration.
    if (boxes_num <= 0) {
      return;
    }

    const int tiles = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
    dim3 blocks(tiles, tiles);
    dim3 threads(THREADS_PER_BLOCK_NMS);

    nms_kernel<<<blocks, threads, 0, stream>>>(boxes_num, nms_overlap_thresh,
                                               boxes, mask);
  }
  // Enqueues nms_normal_kernel on `stream` to compute the axis-aligned
  // suppression mask for `boxes_num` boxes. The grid has one block per
  // (row tile, column tile) pair and THREADS_PER_BLOCK_NMS threads per block.
  // Does nothing when there are no boxes.
  void NmsNormalLauncher(const cudaStream_t &stream, const float *boxes,
                         int64_t *mask, int boxes_num, float nms_overlap_thresh) {
    // A grid with a zero dimension is an invalid launch configuration.
    if (boxes_num <= 0) {
      return;
    }

    const int tiles = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
    dim3 blocks(tiles, tiles);
    dim3 threads(THREADS_PER_BLOCK_NMS);

    nms_normal_kernel<<<blocks, threads, 0, stream>>>(
        boxes_num, nms_overlap_thresh, boxes, mask);
  }