iou3d_cpu.cc 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. /*
  15. 3D Rotated IoU Calculation (CPU)
  16. Written by Shaoshuai Shi
  17. All Rights Reserved 2020.
  18. */
  #include "iou3d_cpu.h"
  #include <math.h>
  #include <stdint.h>
  #include <stdio.h>
  #include <vector>
  23. namespace ultra_infer {
  24. namespace paddle_custom_ops {
  // Returns the smaller of `a` and `b`. When `a > b` is false (equal operands,
  // or an unordered comparison such as NaN) `a` is returned.
  static inline float min(float a, float b) {
    if (a > b) {
      return b;
    }
    return a;
  }

  // Returns the larger of `a` and `b`. When `a > b` is false (equal operands,
  // or an unordered comparison such as NaN) `b` is returned.
  static inline float max(float a, float b) {
    if (a > b) {
      return a;
    }
    return b;
  }
  // Small tolerance shared by the helpers in this file.
  // On Windows it is provided as a macro, after discarding any EPS macro that
  // is already defined; everywhere else it is a file-local float constant.
  // NOTE(review): presumably some Windows header may define EPS itself, which
  // would break the `static const` form -- confirm before unifying the two.
  #ifdef _WIN32
  #ifdef EPS
  #undef EPS
  #endif
  #define EPS 1e-8
  #else
  static const float EPS = 1e-8;
  #endif
  // A 2-D point (also used as a 2-D vector) with single-precision coordinates.
  struct Point {
    float x, y;

    // Default-constructed points are zeroed so that an element of a freshly
    // declared Point array never holds indeterminate values.
    Point() : x(0.0f), y(0.0f) {}

    // Builds a point from double-precision inputs; the values are narrowed to
    // the float storage type.
    Point(double _x, double _y)
        : x(static_cast<float>(_x)), y(static_cast<float>(_y)) {}

    // Overwrites both coordinates in place.
    void set(float _x, float _y) {
      x = _x;
      y = _y;
    }

    // Component-wise sum of two points.
    Point operator+(const Point &b) const { return Point(x + b.x, y + b.y); }

    // Component-wise difference of two points.
    Point operator-(const Point &b) const { return Point(x - b.x, y - b.y); }
  };
  46. static inline float cross(const Point &a, const Point &b) {
  47. return a.x * b.y - a.y * b.x;
  48. }
  49. static inline float cross(const Point &p1, const Point &p2, const Point &p0) {
  50. return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
  51. }
  52. static inline int check_rect_cross(const Point &p1, const Point &p2,
  53. const Point &q1, const Point &q2) {
  54. int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
  55. min(q1.x, q2.x) <= max(p1.x, p2.x) &&
  56. min(p1.y, p2.y) <= max(q1.y, q2.y) &&
  57. min(q1.y, q2.y) <= max(p1.y, p2.y);
  58. return ret;
  59. }
  60. static inline int check_in_box2d(const float *box, const Point &p) {
  61. // params: (7) [x, y, z, dx, dy, dz, heading]
  62. const float MARGIN = 1e-2;
  63. float center_x = box[0], center_y = box[1];
  64. float angle_cos = cos(-box[6]),
  65. angle_sin =
  66. sin(-box[6]); // rotate the point in the opposite direction of box
  67. float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
  68. float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
  69. return (fabs(rot_x) < box[3] / 2 + MARGIN &&
  70. fabs(rot_y) < box[4] / 2 + MARGIN);
  71. }
  72. static inline int intersection(const Point &p1, const Point &p0,
  73. const Point &q1, const Point &q0, Point &ans) {
  74. // fast exclusion
  75. if (check_rect_cross(p0, p1, q0, q1) == 0)
  76. return 0;
  77. // check cross standing
  78. float s1 = cross(q0, p1, p0);
  79. float s2 = cross(p1, q1, p0);
  80. float s3 = cross(p0, q1, q0);
  81. float s4 = cross(q1, p1, q0);
  82. if (!(s1 * s2 > 0 && s3 * s4 > 0))
  83. return 0;
  84. // calculate intersection of two lines
  85. float s5 = cross(q1, p1, p0);
  86. if (fabs(s5 - s1) > EPS) {
  87. ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
  88. ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
  89. } else {
  90. float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
  91. float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
  92. float D = a0 * b1 - a1 * b0;
  93. ans.x = (b0 * c1 - b1 * c0) / D;
  94. ans.y = (a1 * c0 - a0 * c1) / D;
  95. }
  96. return 1;
  97. }
  98. static inline void rotate_around_center(const Point &center,
  99. const float angle_cos,
  100. const float angle_sin, Point &p) {
  101. float new_x =
  102. (p.x - center.x) * angle_cos + (p.y - center.y) * (-angle_sin) + center.x;
  103. float new_y =
  104. (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
  105. p.set(new_x, new_y);
  106. }
  107. static inline int point_cmp(const Point &a, const Point &b,
  108. const Point &center) {
  109. return atan2(a.y - center.y, a.x - center.x) >
  110. atan2(b.y - center.y, b.x - center.x);
  111. }
  112. static inline float box_overlap(const float *box_a, const float *box_b) {
  113. // params: box_a (7) [x, y, z, dx, dy, dz, heading]
  114. // params: box_b (7) [x, y, z, dx, dy, dz, heading]
  115. // float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 =
  116. // box_a[3], a_angle = box_a[4];
  117. // float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 =
  118. // box_b[3], b_angle = box_b[4];
  119. float a_angle = box_a[6], b_angle = box_b[6];
  120. float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
  121. a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
  122. float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
  123. float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
  124. float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
  125. float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
  126. Point center_a(box_a[0], box_a[1]);
  127. Point center_b(box_b[0], box_b[1]);
  128. Point box_a_corners[5];
  129. box_a_corners[0].set(a_x1, a_y1);
  130. box_a_corners[1].set(a_x2, a_y1);
  131. box_a_corners[2].set(a_x2, a_y2);
  132. box_a_corners[3].set(a_x1, a_y2);
  133. Point box_b_corners[5];
  134. box_b_corners[0].set(b_x1, b_y1);
  135. box_b_corners[1].set(b_x2, b_y1);
  136. box_b_corners[2].set(b_x2, b_y2);
  137. box_b_corners[3].set(b_x1, b_y2);
  138. // get oriented corners
  139. float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
  140. float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
  141. for (int k = 0; k < 4; k++) {
  142. rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
  143. rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
  144. }
  145. box_a_corners[4] = box_a_corners[0];
  146. box_b_corners[4] = box_b_corners[0];
  147. // get intersection of lines
  148. Point cross_points[16];
  149. Point poly_center;
  150. int cnt = 0, flag = 0;
  151. poly_center.set(0, 0);
  152. for (int i = 0; i < 4; i++) {
  153. for (int j = 0; j < 4; j++) {
  154. flag = intersection(box_a_corners[i + 1], box_a_corners[i],
  155. box_b_corners[j + 1], box_b_corners[j],
  156. cross_points[cnt]);
  157. if (flag) {
  158. poly_center = poly_center + cross_points[cnt];
  159. cnt++;
  160. }
  161. }
  162. }
  163. // check corners
  164. for (int k = 0; k < 4; k++) {
  165. if (check_in_box2d(box_a, box_b_corners[k])) {
  166. poly_center = poly_center + box_b_corners[k];
  167. cross_points[cnt] = box_b_corners[k];
  168. cnt++;
  169. }
  170. if (check_in_box2d(box_b, box_a_corners[k])) {
  171. poly_center = poly_center + box_a_corners[k];
  172. cross_points[cnt] = box_a_corners[k];
  173. cnt++;
  174. }
  175. }
  176. poly_center.x /= cnt;
  177. poly_center.y /= cnt;
  178. // sort the points of polygon
  179. Point temp;
  180. for (int j = 0; j < cnt - 1; j++) {
  181. for (int i = 0; i < cnt - j - 1; i++) {
  182. if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
  183. temp = cross_points[i];
  184. cross_points[i] = cross_points[i + 1];
  185. cross_points[i + 1] = temp;
  186. }
  187. }
  188. }
  189. // get the overlap areas
  190. float area = 0;
  191. for (int k = 0; k < cnt - 1; k++) {
  192. area += cross(cross_points[k] - cross_points[0],
  193. cross_points[k + 1] - cross_points[0]);
  194. }
  195. return fabs(area) / 2.0;
  196. }
  197. static inline float iou_bev(const float *box_a, const float *box_b) {
  198. // params: box_a (7) [x, y, z, dx, dy, dz, heading]
  199. // params: box_b (7) [x, y, z, dx, dy, dz, heading]
  200. float sa = box_a[3] * box_a[4];
  201. float sb = box_b[3] * box_b[4];
  202. float s_overlap = box_overlap(box_a, box_b);
  203. return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
  204. }
  205. int boxes_iou_bev_cpu(paddle::Tensor boxes_a_tensor,
  206. paddle::Tensor boxes_b_tensor,
  207. paddle::Tensor ans_iou_tensor) {
  208. // params boxes_a_tensor: (N, 7) [x, y, z, dx, dy, dz, heading]
  209. // params boxes_b_tensor: (M, 7) [x, y, z, dx, dy, dz, heading]
  210. // params ans_iou_tensor: (N, M)
  211. // CHECK_CONTIGUOUS(boxes_a_tensor);
  212. // CHECK_CONTIGUOUS(boxes_b_tensor);
  213. int num_boxes_a = boxes_a_tensor.shape()[0];
  214. int num_boxes_b = boxes_b_tensor.shape()[0];
  215. const float *boxes_a = boxes_a_tensor.data<float>();
  216. const float *boxes_b = boxes_b_tensor.data<float>();
  217. float *ans_iou = ans_iou_tensor.data<float>();
  218. for (int i = 0; i < num_boxes_a; i++) {
  219. for (int j = 0; j < num_boxes_b; j++) {
  220. ans_iou[i * num_boxes_b + j] = iou_bev(boxes_a + i * 7, boxes_b + j * 7);
  221. }
  222. }
  223. return 1;
  224. }
  225. } // namespace paddle_custom_ops
  226. } // namespace ultra_infer