multiclass_nms_rotated.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/detection/ppdet/multiclass_nms_rotated.h"
  15. #include <algorithm>
  16. #include <cmath>
  17. #include <opencv2/opencv.hpp>
  18. #include <vector>
  19. #include "ultra_infer/core/fd_tensor.h"
  20. #include "ultra_infer/utils/utils.h"
  21. #include "ultra_infer/vision/detection/ppdet/multiclass_nms.h"
  22. namespace ultra_infer {
  23. namespace vision {
  24. namespace detection {
  25. template <typename T> struct RotatedBox { T x_ctr, y_ctr, w, h, a; };
  26. template <typename T> struct Point {
  27. T x, y;
  28. Point(const T &px = 0, const T &py = 0) : x(px), y(py) {}
  29. Point operator+(const Point &p) const { return Point(x + p.x, y + p.y); }
  30. Point &operator+=(const Point &p) {
  31. x += p.x;
  32. y += p.y;
  33. return *this;
  34. }
  35. Point operator-(const Point &p) const { return Point(x - p.x, y - p.y); }
  36. Point operator*(const T coeff) const { return Point(x * coeff, y * coeff); }
  37. };
// Dot product of the 2-D vectors A and B: A.x * B.x + A.y * B.y.
// Dot2D(V, V) is therefore the squared length of V.
template <typename T> T Dot2D(const Point<T> &A, const Point<T> &B) {
  return A.x * B.x + A.y * B.y;
}
// Scalar (z-component) cross product of the 2-D vectors A and B:
// A.x * B.y - B.x * A.y. It is zero when A and B are parallel, and its sign
// flips when the two arguments are swapped.
template <typename T> T Cross2D(const Point<T> &A, const Point<T> &B) {
  return A.x * B.y - B.x * A.y;
}
  44. template <typename T>
  45. int GetIntersectionPoints(const Point<T> (&pts1)[4], const Point<T> (&pts2)[4],
  46. Point<T> (&intersections)[24]) {
  47. // Line vector
  48. // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  49. Point<T> vec1[4], vec2[4];
  50. for (int i = 0; i < 4; i++) {
  51. vec1[i] = pts1[(i + 1) % 4] - pts1[i];
  52. vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  53. }
  54. // Line test - test all line combos for intersection
  55. int num = 0; // number of intersections
  56. for (int i = 0; i < 4; i++) {
  57. for (int j = 0; j < 4; j++) {
  58. // Solve for 2x2 Ax=b
  59. T det = Cross2D<T>(vec2[j], vec1[i]);
  60. // This takes care of parallel lines
  61. if (fabs(det) <= 1e-14) {
  62. continue;
  63. }
  64. auto vec12 = pts2[j] - pts1[i];
  65. T t1 = Cross2D<T>(vec2[j], vec12) / det;
  66. T t2 = Cross2D<T>(vec1[i], vec12) / det;
  67. if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
  68. intersections[num++] = pts1[i] + vec1[i] * t1;
  69. }
  70. }
  71. }
  72. // Check for vertices of rect1 inside rect2
  73. {
  74. const auto &AB = vec2[0];
  75. const auto &DA = vec2[3];
  76. auto ABdotAB = Dot2D<T>(AB, AB);
  77. auto ADdotAD = Dot2D<T>(DA, DA);
  78. for (int i = 0; i < 4; i++) {
  79. // assume ABCD is the rectangle, and P is the point to be judged
  80. // P is inside ABCD iff. P's projection on AB lies within AB
  81. // and P's projection on AD lies within AD
  82. auto AP = pts1[i] - pts2[0];
  83. auto APdotAB = Dot2D<T>(AP, AB);
  84. auto APdotAD = -Dot2D<T>(AP, DA);
  85. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  86. (APdotAD <= ADdotAD)) {
  87. intersections[num++] = pts1[i];
  88. }
  89. }
  90. }
  91. // Reverse the check - check for vertices of rect2 inside rect1
  92. {
  93. const auto &AB = vec1[0];
  94. const auto &DA = vec1[3];
  95. auto ABdotAB = Dot2D<T>(AB, AB);
  96. auto ADdotAD = Dot2D<T>(DA, DA);
  97. for (int i = 0; i < 4; i++) {
  98. auto AP = pts2[i] - pts1[0];
  99. auto APdotAB = Dot2D<T>(AP, AB);
  100. auto APdotAD = -Dot2D<T>(AP, DA);
  101. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  102. (APdotAD <= ADdotAD)) {
  103. intersections[num++] = pts2[i];
  104. }
  105. }
  106. }
  107. return num;
  108. }
  109. template <typename T>
  110. int ConvexHullGraham(const Point<T> (&p)[24], const int &num_in,
  111. Point<T> (&q)[24], bool shift_to_zero = false) {
  112. assert(num_in >= 2);
  113. // Step 1:
  114. // Find point with minimum y
  115. // if more than 1 points have the same minimum y,
  116. // pick the one with the minimum x.
  117. int t = 0;
  118. for (int i = 1; i < num_in; i++) {
  119. if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
  120. t = i;
  121. }
  122. }
  123. auto &start = p[t]; // starting point
  124. // Step 2:
  125. // Subtract starting point from every points (for sorting in the next step)
  126. for (int i = 0; i < num_in; i++) {
  127. q[i] = p[i] - start;
  128. }
  129. // Swap the starting point to position 0
  130. auto tmp = q[0];
  131. q[0] = q[t];
  132. q[t] = tmp;
  133. // Step 3:
  134. // Sort point 1 ~ num_in according to their relative cross-product values
  135. // (essentially sorting according to angles)
  136. // If the angles are the same, sort according to their distance to origin
  137. T dist[24];
  138. for (int i = 0; i < num_in; i++) {
  139. dist[i] = Dot2D<T>(q[i], q[i]);
  140. }
  141. // CPU version
  142. std::sort(q + 1, q + num_in,
  143. [](const Point<T> &A, const Point<T> &B) -> bool {
  144. T temp = Cross2D<T>(A, B);
  145. if (fabs(temp) < 1e-6) {
  146. return Dot2D<T>(A, A) < Dot2D<T>(B, B);
  147. } else {
  148. return temp > 0;
  149. }
  150. });
  151. // Step 4:
  152. // Make sure there are at least 2 points (that don't overlap with each other)
  153. // in the stack
  154. int k; // index of the non-overlapped second point
  155. for (k = 1; k < num_in; k++) {
  156. if (dist[k] > 1e-8) {
  157. break;
  158. }
  159. }
  160. if (k == num_in) {
  161. // We reach the end, which means the convex hull is just one point
  162. q[0] = p[t];
  163. return 1;
  164. }
  165. q[1] = q[k];
  166. int m = 2; // 2 points in the stack
  167. // Step 5:
  168. // Finally we can start the scanning process.
  169. // When a non-convex relationship between the 3 points is found
  170. // (either concave shape or duplicated points),
  171. // we pop the previous point from the stack
  172. // until the 3-point relationship is convex again, or
  173. // until the stack only contains two points
  174. for (int i = k + 1; i < num_in; i++) {
  175. while (m > 1 && Cross2D<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
  176. m--;
  177. }
  178. q[m++] = q[i];
  179. }
  180. // Step 6 (Optional):
  181. // In general sense we need the original coordinates, so we
  182. // need to shift the points back (reverting Step 2)
  183. // But if we're only interested in getting the area/perimeter of the shape
  184. // We can simply return.
  185. if (!shift_to_zero) {
  186. for (int i = 0; i < m; i++) {
  187. q[i] += start;
  188. }
  189. }
  190. return m;
  191. }
  192. template <typename T> T PolygonArea(const Point<T> (&q)[24], const int &m) {
  193. if (m <= 2) {
  194. return 0;
  195. }
  196. T area = 0;
  197. for (int i = 1; i < m - 1; i++) {
  198. area += fabs(Cross2D<T>(q[i] - q[0], q[i + 1] - q[0]));
  199. }
  200. return area / 2.0;
  201. }
  202. template <typename T>
  203. T RboxesIntersection(T const *const poly1_raw, T const *const poly2_raw) {
  204. // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  205. // from rotated_rect_intersection_pts
  206. Point<T> intersectPts[24], orderedPts[24];
  207. Point<T> pts1[4];
  208. Point<T> pts2[4];
  209. for (int i = 0; i < 4; i++) {
  210. pts1[i] = Point<T>(poly1_raw[2 * i], poly1_raw[2 * i + 1]);
  211. pts2[i] = Point<T>(poly2_raw[2 * i], poly2_raw[2 * i + 1]);
  212. }
  213. int num = GetIntersectionPoints<T>(pts1, pts2, intersectPts);
  214. if (num <= 2) {
  215. return 0.0;
  216. }
  217. // Convex Hull to order the intersection points in clockwise order and find
  218. // the contour area.
  219. int num_convex = ConvexHullGraham<T>(intersectPts, num, orderedPts, true);
  220. return PolygonArea<T>(orderedPts, num_convex);
  221. }
  222. template <typename T> T PolyArea(T const *const poly_raw) {
  223. T area = 0.0;
  224. int j = 3;
  225. for (int i = 0; i < 4; i++) {
  226. // area += (x[j] + x[i]) * (y[j] - y[i]);
  227. area += (poly_raw[2 * j] + poly_raw[2 * i]) *
  228. (poly_raw[2 * j + 1] - poly_raw[2 * i + 1]);
  229. j = i;
  230. }
  231. // return static_cast<T>(abs(static_cast<float>(area) / 2.0));
  232. return std::abs(area / 2.0);
  233. }
  234. template <typename T>
  235. void Poly2Rbox(T const *const poly_raw, RotatedBox<T> &box) {
  236. std::vector<cv::Point2f> contour_poly{
  237. cv::Point2f(poly_raw[0], poly_raw[1]),
  238. cv::Point2f(poly_raw[2], poly_raw[3]),
  239. cv::Point2f(poly_raw[4], poly_raw[5]),
  240. cv::Point2f(poly_raw[6], poly_raw[7]),
  241. };
  242. cv::RotatedRect rotate_rect = cv::minAreaRect(contour_poly);
  243. box.x_ctr = rotate_rect.center.x;
  244. box.y_ctr = rotate_rect.center.y;
  245. box.w = rotate_rect.size.width;
  246. box.h = rotate_rect.size.height;
  247. box.a = rotate_rect.angle;
  248. }
  249. template <typename T>
  250. T RboxIouSingle(T const *const poly1_raw, T const *const poly2_raw) {
  251. const T area1 = PolyArea(poly1_raw);
  252. const T area2 = PolyArea(poly2_raw);
  253. const T intersection = RboxesIntersection<T>(poly1_raw, poly2_raw);
  254. const T iou = intersection / (area1 + area2 - intersection);
  255. return iou;
  256. }
  257. template <typename T>
  258. bool SortScorePairDescendRotated(const std::pair<float, T> &pair1,
  259. const std::pair<float, T> &pair2) {
  260. return pair1.first > pair2.first;
  261. }
  262. void GetMaxScoreIndexRotated(
  263. const float *scores, const int &score_size, const float &threshold,
  264. const int &top_k, std::vector<std::pair<float, int>> *sorted_indices) {
  265. for (size_t i = 0; i < score_size; ++i) {
  266. if (scores[i] > threshold) {
  267. sorted_indices->push_back(std::make_pair(scores[i], i));
  268. }
  269. }
  270. // Sort the score pair according to the scores in descending order
  271. std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
  272. SortScorePairDescendRotated<int>);
  273. // Keep top_k scores if needed.
  274. if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
  275. sorted_indices->resize(top_k);
  276. }
  277. }
  278. void PaddleMultiClassNMSRotated::FastNMSRotated(
  279. const float *boxes, const float *scores, const int &num_boxes,
  280. std::vector<int> *keep_indices) {
  281. std::vector<std::pair<float, int>> sorted_indices;
  282. GetMaxScoreIndexRotated(scores, num_boxes, score_threshold, nms_top_k,
  283. &sorted_indices);
  284. // printf("nms thrd: %f, sort dim: %d\n", nms_threshold,
  285. // int(sorted_indices.size()));
  286. float adaptive_threshold = nms_threshold;
  287. while (sorted_indices.size() != 0) {
  288. const int idx = sorted_indices.front().second;
  289. bool keep = true;
  290. for (size_t k = 0; k < keep_indices->size(); ++k) {
  291. if (!keep) {
  292. break;
  293. }
  294. const int kept_idx = (*keep_indices)[k];
  295. float overlap =
  296. RboxIouSingle<float>(boxes + idx * 8, boxes + kept_idx * 8);
  297. keep = overlap <= adaptive_threshold;
  298. }
  299. if (keep) {
  300. keep_indices->push_back(idx);
  301. }
  302. sorted_indices.erase(sorted_indices.begin());
  303. if (keep && nms_eta<1.0 & adaptive_threshold> 0.5) {
  304. adaptive_threshold *= nms_eta;
  305. }
  306. }
  307. }
  308. int PaddleMultiClassNMSRotated::NMSRotatedForEachSample(
  309. const float *boxes, const float *scores, int num_boxes, int num_classes,
  310. std::map<int, std::vector<int>> *keep_indices) {
  311. for (int i = 0; i < num_classes; ++i) {
  312. if (i == background_label) {
  313. continue;
  314. }
  315. const float *score_for_class_i = scores + i * num_boxes;
  316. FastNMSRotated(boxes, score_for_class_i, num_boxes, &((*keep_indices)[i]));
  317. }
  318. int num_det = 0;
  319. for (auto iter = keep_indices->begin(); iter != keep_indices->end(); ++iter) {
  320. num_det += iter->second.size();
  321. }
  322. if (keep_top_k > -1 && num_det > keep_top_k) {
  323. std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
  324. for (const auto &it : *keep_indices) {
  325. int label = it.first;
  326. const float *current_score = scores + label * num_boxes;
  327. auto &label_indices = it.second;
  328. for (size_t j = 0; j < label_indices.size(); ++j) {
  329. int idx = label_indices[j];
  330. score_index_pairs.push_back(
  331. std::make_pair(current_score[idx], std::make_pair(label, idx)));
  332. }
  333. }
  334. std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
  335. SortScorePairDescendRotated<std::pair<int, int>>);
  336. score_index_pairs.resize(keep_top_k);
  337. std::map<int, std::vector<int>> new_indices;
  338. for (size_t j = 0; j < score_index_pairs.size(); ++j) {
  339. int label = score_index_pairs[j].second.first;
  340. int idx = score_index_pairs[j].second.second;
  341. new_indices[label].push_back(idx);
  342. }
  343. new_indices.swap(*keep_indices);
  344. num_det = keep_top_k;
  345. }
  346. return num_det;
  347. }
  348. void PaddleMultiClassNMSRotated::Compute(
  349. const float *boxes_data, const float *scores_data,
  350. const std::vector<int64_t> &boxes_dim,
  351. const std::vector<int64_t> &scores_dim) {
  352. int score_size = scores_dim.size();
  353. int64_t batch_size = scores_dim[0];
  354. int64_t box_dim = boxes_dim[2];
  355. int64_t out_dim = box_dim + 2;
  356. int num_nmsed_out = 0;
  357. FDASSERT(score_size == 3,
  358. "Require rank of input scores be 3, but now it's %d.", score_size);
  359. FDASSERT(boxes_dim[2] == 8,
  360. "Require the 3-dimension of input boxes be 8, but now it's %lld.",
  361. box_dim);
  362. out_num_rois_data.resize(batch_size);
  363. std::vector<std::map<int, std::vector<int>>> all_indices;
  364. for (size_t i = 0; i < batch_size; ++i) {
  365. std::map<int, std::vector<int>> indices; // indices kept for each class
  366. const float *current_boxes_ptr =
  367. boxes_data + i * boxes_dim[1] * boxes_dim[2];
  368. const float *current_scores_ptr =
  369. scores_data + i * scores_dim[1] * scores_dim[2];
  370. int num = NMSRotatedForEachSample(current_boxes_ptr, current_scores_ptr,
  371. boxes_dim[1], scores_dim[1], &indices);
  372. num_nmsed_out += num;
  373. out_num_rois_data[i] = num;
  374. all_indices.emplace_back(indices);
  375. }
  376. std::vector<int64_t> out_box_dims = {num_nmsed_out, 10};
  377. std::vector<int64_t> out_index_dims = {num_nmsed_out, 1};
  378. if (num_nmsed_out == 0) {
  379. for (size_t i = 0; i < batch_size; ++i) {
  380. out_num_rois_data[i] = 0;
  381. }
  382. return;
  383. }
  384. out_box_data.resize(num_nmsed_out * 10);
  385. out_index_data.resize(num_nmsed_out);
  386. int count = 0;
  387. for (size_t i = 0; i < batch_size; ++i) {
  388. const float *current_boxes_ptr =
  389. boxes_data + i * boxes_dim[1] * boxes_dim[2];
  390. const float *current_scores_ptr =
  391. scores_data + i * scores_dim[1] * scores_dim[2];
  392. for (const auto &it : all_indices[i]) {
  393. int label = it.first;
  394. const auto &indices = it.second;
  395. const float *current_scores_class_ptr =
  396. current_scores_ptr + label * scores_dim[2];
  397. for (size_t j = 0; j < indices.size(); ++j) {
  398. int start = count * 10;
  399. out_box_data[start] = label;
  400. out_box_data[start + 1] = current_scores_class_ptr[indices[j]];
  401. for (int k = 0; k < 8; k++) {
  402. out_box_data[start + 2 + k] = current_boxes_ptr[indices[j] * 8 + k];
  403. }
  404. out_index_data[count] = i * boxes_dim[1] + indices[j];
  405. count += 1;
  406. }
  407. }
  408. }
  409. }
  410. } // namespace detection
  411. } // namespace vision
  412. } // namespace ultra_infer