classification.cc 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include "opencv2/imgproc/imgproc.hpp"
  16. #include "ultra_infer/vision/visualize/visualize.h"
  17. namespace ultra_infer {
  18. namespace vision {
  19. cv::Mat VisClassification(const cv::Mat &im, const ClassifyResult &result,
  20. int top_k, float score_threshold, float font_size) {
  21. int h = im.rows;
  22. int w = im.cols;
  23. auto vis_im = im.clone();
  24. int h_sep = h / 30;
  25. int w_sep = w / 10;
  26. if (top_k > result.scores.size()) {
  27. top_k = result.scores.size();
  28. }
  29. for (int i = 0; i < top_k; ++i) {
  30. if (result.scores[i] < score_threshold) {
  31. continue;
  32. }
  33. std::string id = std::to_string(result.label_ids[i]);
  34. std::string score = std::to_string(result.scores[i]);
  35. if (score.size() > 4) {
  36. score = score.substr(0, 4);
  37. }
  38. std::string text = id + "," + score;
  39. int font = cv::FONT_HERSHEY_SIMPLEX;
  40. cv::Point origin;
  41. origin.x = w_sep;
  42. origin.y = h_sep * (i + 1);
  43. cv::putText(vis_im, text, origin, font, font_size,
  44. cv::Scalar(255, 255, 255), 1);
  45. }
  46. return vis_im;
  47. }
  48. // Visualize ClassifyResult with custom labels.
  49. cv::Mat VisClassification(const cv::Mat &im, const ClassifyResult &result,
  50. const std::vector<std::string> &labels, int top_k,
  51. float score_threshold, float font_size) {
  52. int h = im.rows;
  53. int w = im.cols;
  54. auto vis_im = im.clone();
  55. int h_sep = h / 30;
  56. int w_sep = w / 10;
  57. if (top_k > result.scores.size()) {
  58. top_k = result.scores.size();
  59. }
  60. for (int i = 0; i < top_k; ++i) {
  61. if (result.scores[i] < score_threshold) {
  62. continue;
  63. }
  64. std::string id = std::to_string(result.label_ids[i]);
  65. std::string score = std::to_string(result.scores[i]);
  66. if (score.size() > 4) {
  67. score = score.substr(0, 4);
  68. }
  69. std::string text = id + "," + score;
  70. if (labels.size() > result.label_ids[i]) {
  71. text = labels[result.label_ids[i]] + "," + text;
  72. } else {
  73. FDWARNING << "The label_id: " << result.label_ids[i]
  74. << " in DetectionResult should be less than length of labels:"
  75. << labels.size() << "." << std::endl;
  76. }
  77. if (text.size() > 16) {
  78. text = text.substr(0, 16);
  79. }
  80. int font = cv::FONT_HERSHEY_SIMPLEX;
  81. cv::Point origin;
  82. origin.x = w_sep;
  83. origin.y = h_sep * (i + 1);
  84. cv::putText(vis_im, text, origin, font, font_size,
  85. cv::Scalar(255, 255, 255), 1);
  86. }
  87. return vis_im;
  88. }
  89. } // namespace vision
  90. } // namespace ultra_infer