multiclass_nms.cc 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/detection/ppdet/multiclass_nms.h"
  15. #include "ultra_infer/core/fd_tensor.h"
  16. #include "ultra_infer/utils/utils.h"
  17. #include <algorithm>
  18. namespace ultra_infer {
  19. namespace vision {
  20. namespace detection {
  21. template <class T>
  22. bool SortScorePairDescend(const std::pair<float, T> &pair1,
  23. const std::pair<float, T> &pair2) {
  24. return pair1.first > pair2.first;
  25. }
  /// Collects the indices of all scores strictly greater than `threshold`,
  /// orders them by descending score, and optionally truncates the result.
  ///
  /// @param scores          Pointer to `score_size` contiguous scores.
  /// @param score_size      Number of scores to scan; a non-positive value
  ///                        yields no candidates.
  /// @param threshold       A score must be strictly greater than this value
  ///                        to become a candidate.
  /// @param top_k           Maximum number of entries to keep; any negative
  ///                        value disables truncation.
  /// @param sorted_indices  Output; (score, index) pairs are appended and the
  ///                        whole vector is then stable-sorted by descending
  ///                        score, so equal scores keep ascending index order.
  void GetMaxScoreIndex(const float *scores, const int &score_size,
                        const float &threshold, const int &top_k,
                        std::vector<std::pair<float, int>> *sorted_indices) {
    // Signed counter on purpose: comparing a size_t counter with a negative
    // `score_size` would turn the bound into a huge unsigned value and read
    // past the end of `scores`.
    for (int i = 0; i < score_size; ++i) {
      if (scores[i] > threshold) {
        sorted_indices->emplace_back(scores[i], i);
      }
    }
    // Sort by score, highest first; stability keeps ties in index order.
    std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
                     [](const std::pair<float, int> &lhs,
                        const std::pair<float, int> &rhs) {
                       return lhs.first > rhs.first;
                     });
    // Keep only the best `top_k` candidates when a limit is requested.
    if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
      sorted_indices->resize(top_k);
    }
  }
  /// Area of an axis-aligned box stored as [xmin, ymin, xmax, ymax].
  ///
  /// A box whose max corner lies before its min corner on either axis is
  /// invalid and gets area 0. When `normalized` is false the coordinates are
  /// treated as inclusive pixel indices, so 1 is added to width and height.
  float BBoxArea(const float *box, const bool &normalized) {
    const float xmin = box[0];
    const float ymin = box[1];
    const float xmax = box[2];
    const float ymax = box[3];
    const bool is_valid = !(xmax < xmin) && !(ymax < ymin);
    if (!is_valid) {
      return 0.f;
    }
    const float width = xmax - xmin;
    const float height = ymax - ymin;
    return normalized ? width * height : (width + 1) * (height + 1);
  }
  58. float JaccardOverlap(const float *box1, const float *box2,
  59. const bool &normalized) {
  60. if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
  61. box2[3] < box1[1]) {
  62. return 0.f;
  63. } else {
  64. const float inter_xmin = std::max(box1[0], box2[0]);
  65. const float inter_ymin = std::max(box1[1], box2[1]);
  66. const float inter_xmax = std::min(box1[2], box2[2]);
  67. const float inter_ymax = std::min(box1[3], box2[3]);
  68. float norm = normalized ? 0.0f : 1.0f;
  69. float inter_w = inter_xmax - inter_xmin + norm;
  70. float inter_h = inter_ymax - inter_ymin + norm;
  71. const float inter_area = inter_w * inter_h;
  72. const float bbox1_area = BBoxArea(box1, normalized);
  73. const float bbox2_area = BBoxArea(box2, normalized);
  74. return inter_area / (bbox1_area + bbox2_area - inter_area);
  75. }
  76. }
  77. void PaddleMultiClassNMS::FastNMS(const float *boxes, const float *scores,
  78. const int &num_boxes,
  79. std::vector<int> *keep_indices) {
  80. std::vector<std::pair<float, int>> sorted_indices;
  81. GetMaxScoreIndex(scores, num_boxes, score_threshold, nms_top_k,
  82. &sorted_indices);
  83. float adaptive_threshold = nms_threshold;
  84. while (sorted_indices.size() != 0) {
  85. const int idx = sorted_indices.front().second;
  86. bool keep = true;
  87. for (size_t k = 0; k < keep_indices->size(); ++k) {
  88. if (!keep) {
  89. break;
  90. }
  91. const int kept_idx = (*keep_indices)[k];
  92. float overlap =
  93. JaccardOverlap(boxes + idx * 4, boxes + kept_idx * 4, normalized);
  94. keep = overlap <= adaptive_threshold;
  95. }
  96. if (keep) {
  97. keep_indices->push_back(idx);
  98. }
  99. sorted_indices.erase(sorted_indices.begin());
  100. if (keep && nms_eta<1.0 & adaptive_threshold> 0.5) {
  101. adaptive_threshold *= nms_eta;
  102. }
  103. }
  104. }
  105. int PaddleMultiClassNMS::NMSForEachSample(
  106. const float *boxes, const float *scores, int num_boxes, int num_classes,
  107. std::map<int, std::vector<int>> *keep_indices) {
  108. for (int i = 0; i < num_classes; ++i) {
  109. if (i == background_label) {
  110. continue;
  111. }
  112. const float *score_for_class_i = scores + i * num_boxes;
  113. FastNMS(boxes, score_for_class_i, num_boxes, &((*keep_indices)[i]));
  114. }
  115. int num_det = 0;
  116. for (auto iter = keep_indices->begin(); iter != keep_indices->end(); ++iter) {
  117. num_det += iter->second.size();
  118. }
  119. if (keep_top_k > -1 && num_det > keep_top_k) {
  120. std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
  121. for (const auto &it : *keep_indices) {
  122. int label = it.first;
  123. const float *current_score = scores + label * num_boxes;
  124. auto &label_indices = it.second;
  125. for (size_t j = 0; j < label_indices.size(); ++j) {
  126. int idx = label_indices[j];
  127. score_index_pairs.push_back(
  128. std::make_pair(current_score[idx], std::make_pair(label, idx)));
  129. }
  130. }
  131. std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
  132. SortScorePairDescend<std::pair<int, int>>);
  133. score_index_pairs.resize(keep_top_k);
  134. std::map<int, std::vector<int>> new_indices;
  135. for (size_t j = 0; j < score_index_pairs.size(); ++j) {
  136. int label = score_index_pairs[j].second.first;
  137. int idx = score_index_pairs[j].second.second;
  138. new_indices[label].push_back(idx);
  139. }
  140. new_indices.swap(*keep_indices);
  141. num_det = keep_top_k;
  142. }
  143. return num_det;
  144. }
  145. void PaddleMultiClassNMS::Compute(const float *boxes_data,
  146. const float *scores_data,
  147. const std::vector<int64_t> &boxes_dim,
  148. const std::vector<int64_t> &scores_dim) {
  149. int score_size = scores_dim.size();
  150. int64_t batch_size = scores_dim[0];
  151. int64_t box_dim = boxes_dim[2];
  152. int64_t out_dim = box_dim + 2;
  153. int num_nmsed_out = 0;
  154. FDASSERT(score_size == 3,
  155. "Require rank of input scores be 3, but now it's %d.", score_size);
  156. FDASSERT(boxes_dim[2] == 4,
  157. "Require the 3-dimension of input boxes be 4, but now it's %lld.",
  158. box_dim);
  159. out_num_rois_data.resize(batch_size);
  160. std::vector<std::map<int, std::vector<int>>> all_indices;
  161. for (size_t i = 0; i < batch_size; ++i) {
  162. std::map<int, std::vector<int>> indices; // indices kept for each class
  163. const float *current_boxes_ptr =
  164. boxes_data + i * boxes_dim[1] * boxes_dim[2];
  165. const float *current_scores_ptr =
  166. scores_data + i * scores_dim[1] * scores_dim[2];
  167. int num = NMSForEachSample(current_boxes_ptr, current_scores_ptr,
  168. boxes_dim[1], scores_dim[1], &indices);
  169. num_nmsed_out += num;
  170. out_num_rois_data[i] = num;
  171. all_indices.emplace_back(indices);
  172. }
  173. std::vector<int64_t> out_box_dims = {num_nmsed_out, 6};
  174. std::vector<int64_t> out_index_dims = {num_nmsed_out, 1};
  175. if (num_nmsed_out == 0) {
  176. for (size_t i = 0; i < batch_size; ++i) {
  177. out_num_rois_data[i] = 0;
  178. }
  179. return;
  180. }
  181. out_box_data.resize(num_nmsed_out * 6);
  182. out_index_data.resize(num_nmsed_out);
  183. int count = 0;
  184. for (size_t i = 0; i < batch_size; ++i) {
  185. const float *current_boxes_ptr =
  186. boxes_data + i * boxes_dim[1] * boxes_dim[2];
  187. const float *current_scores_ptr =
  188. scores_data + i * scores_dim[1] * scores_dim[2];
  189. for (const auto &it : all_indices[i]) {
  190. int label = it.first;
  191. const auto &indices = it.second;
  192. const float *current_scores_class_ptr =
  193. current_scores_ptr + label * scores_dim[2];
  194. for (size_t j = 0; j < indices.size(); ++j) {
  195. int start = count * 6;
  196. out_box_data[start] = label;
  197. out_box_data[start + 1] = current_scores_class_ptr[indices[j]];
  198. out_box_data[start + 2] = current_boxes_ptr[indices[j] * 4];
  199. out_box_data[start + 3] = current_boxes_ptr[indices[j] * 4 + 1];
  200. out_box_data[start + 4] = current_boxes_ptr[indices[j] * 4 + 2];
  201. out_box_data[start + 5] = current_boxes_ptr[indices[j] * 4 + 3];
  202. out_index_data[count] = i * boxes_dim[1] + indices[j];
  203. count += 1;
  204. }
  205. }
  206. }
  207. }
  208. } // namespace detection
  209. } // namespace vision
  210. } // namespace ultra_infer