detection.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include "opencv2/imgproc/imgproc.hpp"
  16. #include "ultra_infer/vision/visualize/visualize.h"
  17. namespace ultra_infer {
  18. namespace vision {
  19. cv::Mat VisDetection(const cv::Mat &im, const DetectionResult &result,
  20. float score_threshold, int line_size, float font_size) {
  21. if (result.boxes.empty() && result.rotated_boxes.empty()) {
  22. return im;
  23. }
  24. if (result.contain_masks) {
  25. FDASSERT(result.boxes.size() == result.masks.size(),
  26. "The size of masks must be equal to the size of boxes, but now "
  27. "%zu != %zu.",
  28. result.boxes.size(), result.masks.size());
  29. }
  30. int max_label_id =
  31. *std::max_element(result.label_ids.begin(), result.label_ids.end());
  32. std::vector<int> color_map = GenerateColorMap(max_label_id);
  33. int h = im.rows;
  34. int w = im.cols;
  35. auto vis_im = im.clone();
  36. for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
  37. if (result.scores[i] < score_threshold) {
  38. continue;
  39. }
  40. int c0 = color_map[3 * result.label_ids[i] + 0];
  41. int c1 = color_map[3 * result.label_ids[i] + 1];
  42. int c2 = color_map[3 * result.label_ids[i] + 2];
  43. cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  44. std::string id = std::to_string(result.label_ids[i]);
  45. std::string score = std::to_string(result.scores[i]);
  46. if (score.size() > 4) {
  47. score = score.substr(0, 4);
  48. }
  49. std::string text = id + ", " + score;
  50. int font = cv::FONT_HERSHEY_SIMPLEX;
  51. cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  52. for (int j = 0; j < 4; j++) {
  53. auto start = cv::Point(
  54. static_cast<int>(round(result.rotated_boxes[i][2 * j])),
  55. static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));
  56. cv::Point end;
  57. if (j != 3) {
  58. end = cv::Point(
  59. static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1)])),
  60. static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1) + 1])));
  61. } else {
  62. end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
  63. static_cast<int>(round(result.rotated_boxes[i][1])));
  64. cv::putText(vis_im, text, end, font, font_size,
  65. cv::Scalar(255, 255, 255), 1);
  66. }
  67. cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
  68. 0);
  69. }
  70. }
  71. for (size_t box_i = 0; box_i < result.boxes.size(); ++box_i) {
  72. if (result.scores[box_i] < score_threshold) {
  73. continue;
  74. }
  75. int x1 = static_cast<int>(round(result.boxes[box_i][0]));
  76. int y1 = static_cast<int>(round(result.boxes[box_i][1]));
  77. int x2 = static_cast<int>(round(result.boxes[box_i][2]));
  78. int y2 = static_cast<int>(round(result.boxes[box_i][3]));
  79. int box_h = y2 - y1;
  80. int box_w = x2 - x1;
  81. int c0 = color_map[3 * result.label_ids[box_i] + 0];
  82. int c1 = color_map[3 * result.label_ids[box_i] + 1];
  83. int c2 = color_map[3 * result.label_ids[box_i] + 2];
  84. cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  85. std::string id = std::to_string(result.label_ids[box_i]);
  86. std::string score = std::to_string(result.scores[box_i]);
  87. if (score.size() > 4) {
  88. score = score.substr(0, 4);
  89. }
  90. std::string text = id + ", " + score;
  91. int font = cv::FONT_HERSHEY_SIMPLEX;
  92. cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  93. cv::Point origin;
  94. origin.x = x1;
  95. origin.y = y1;
  96. cv::Rect rect(x1, y1, box_w, box_h);
  97. cv::rectangle(vis_im, rect, rect_color, line_size);
  98. cv::putText(vis_im, text, origin, font, font_size,
  99. cv::Scalar(255, 255, 255), 1);
  100. if (result.contain_masks) {
  101. int mask_h = static_cast<int>(result.masks[box_i].shape[0]);
  102. int mask_w = static_cast<int>(result.masks[box_i].shape[1]);
  103. // non-const pointer for cv:Mat constructor
  104. uint32_t *mask_raw_data = const_cast<uint32_t *>(
  105. static_cast<const uint32_t *>(result.masks[box_i].Data()));
  106. // only reference to mask data (zero copy)
  107. cv::Mat mask(mask_h, mask_w, CV_32SC1, mask_raw_data);
  108. if ((mask_h != box_h) || (mask_w != box_w)) {
  109. cv::resize(mask, mask, cv::Size(box_w, box_h));
  110. }
  111. // use a bright color for instance mask
  112. int mc0 = 255 - c0 >= 127 ? 255 - c0 : 127;
  113. int mc1 = 255 - c1 >= 127 ? 255 - c1 : 127;
  114. int mc2 = 255 - c2 >= 127 ? 255 - c2 : 127;
  115. uint32_t *mask_data = reinterpret_cast<uint32_t *>(mask.data);
  116. // inplace blending (zero copy)
  117. uchar *vis_im_data = static_cast<uchar *>(vis_im.data);
  118. for (size_t i = y1; i < y2; ++i) {
  119. for (size_t j = x1; j < x2; ++j) {
  120. if (mask_data[(i - y1) * mask_w + (j - x1)] != 0) {
  121. vis_im_data[i * w * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
  122. static_cast<float>(mc0) * 0.5f +
  123. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 0]) * 0.5f);
  124. vis_im_data[i * w * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
  125. static_cast<float>(mc1) * 0.5f +
  126. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 1]) * 0.5f);
  127. vis_im_data[i * w * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
  128. static_cast<float>(mc2) * 0.5f +
  129. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 2]) * 0.5f);
  130. }
  131. }
  132. }
  133. }
  134. }
  135. return vis_im;
  136. }
  137. // Visualize DetectionResult with custom labels.
  138. cv::Mat VisDetection(const cv::Mat &im, const DetectionResult &result,
  139. const std::vector<std::string> &labels,
  140. float score_threshold, int line_size, float font_size,
  141. std::vector<int> font_color, int font_thickness) {
  142. if (result.boxes.empty()) {
  143. return im;
  144. }
  145. if (result.contain_masks) {
  146. FDASSERT(result.boxes.size() == result.masks.size(),
  147. "The size of masks must be equal to the size of boxes, but now "
  148. "%zu != %zu.",
  149. result.boxes.size(), result.masks.size());
  150. }
  151. int max_label_id =
  152. *std::max_element(result.label_ids.begin(), result.label_ids.end());
  153. std::vector<int> color_map = GenerateColorMap(max_label_id);
  154. int h = im.rows;
  155. int w = im.cols;
  156. auto vis_im = im.clone();
  157. auto font_color_ = cv::Scalar(font_color[0], font_color[1], font_color[2]);
  158. for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
  159. if (result.scores[i] < score_threshold) {
  160. continue;
  161. }
  162. int c0 = color_map[3 * result.label_ids[i] + 0];
  163. int c1 = color_map[3 * result.label_ids[i] + 1];
  164. int c2 = color_map[3 * result.label_ids[i] + 2];
  165. cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  166. std::string id = std::to_string(result.label_ids[i]);
  167. std::string score = std::to_string(result.scores[i]);
  168. if (score.size() > 4) {
  169. score = score.substr(0, 4);
  170. }
  171. std::string text = id + ", " + score;
  172. int font = cv::FONT_HERSHEY_SIMPLEX;
  173. cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  174. for (int j = 0; j < 4; j++) {
  175. auto start = cv::Point(
  176. static_cast<int>(round(result.rotated_boxes[i][2 * j])),
  177. static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));
  178. cv::Point end;
  179. if (j == 3) {
  180. end = cv::Point(
  181. static_cast<int>(round(result.rotated_boxes[i][2 * j])),
  182. static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));
  183. } else {
  184. end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
  185. static_cast<int>(round(result.rotated_boxes[i][1])));
  186. cv::putText(vis_im, text, end, font, font_size, font_color_,
  187. font_thickness);
  188. }
  189. cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
  190. 0);
  191. }
  192. }
  193. for (size_t i = 0; i < result.boxes.size(); ++i) {
  194. if (result.scores[i] < score_threshold) {
  195. continue;
  196. }
  197. int x1 = static_cast<int>(result.boxes[i][0]);
  198. int y1 = static_cast<int>(result.boxes[i][1]);
  199. int x2 = static_cast<int>(result.boxes[i][2]);
  200. int y2 = static_cast<int>(result.boxes[i][3]);
  201. int box_h = y2 - y1;
  202. int box_w = x2 - x1;
  203. int c0 = color_map[3 * result.label_ids[i] + 0];
  204. int c1 = color_map[3 * result.label_ids[i] + 1];
  205. int c2 = color_map[3 * result.label_ids[i] + 2];
  206. cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  207. std::string id = std::to_string(result.label_ids[i]);
  208. std::string score = std::to_string(result.scores[i]);
  209. if (score.size() > 4) {
  210. score = score.substr(0, 4);
  211. }
  212. std::string text = id + "," + score;
  213. if (labels.size() > result.label_ids[i]) {
  214. text = labels[result.label_ids[i]] + "," + text;
  215. } else {
  216. FDWARNING << "The label_id: " << result.label_ids[i]
  217. << " in DetectionResult should be less than length of labels:"
  218. << labels.size() << "." << std::endl;
  219. }
  220. if (text.size() > 16) {
  221. text = text.substr(0, 16);
  222. }
  223. int font = cv::FONT_HERSHEY_SIMPLEX;
  224. cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  225. cv::Point origin;
  226. origin.x = x1;
  227. origin.y = y1;
  228. cv::Rect rect(x1, y1, box_w, box_h);
  229. cv::rectangle(vis_im, rect, rect_color, line_size);
  230. cv::putText(vis_im, text, origin, font, font_size, font_color_,
  231. font_thickness);
  232. if (result.contain_masks) {
  233. int mask_h = static_cast<int>(result.masks[i].shape[0]);
  234. int mask_w = static_cast<int>(result.masks[i].shape[1]);
  235. // non-const pointer for cv:Mat constructor
  236. int32_t *mask_raw_data = const_cast<int32_t *>(
  237. static_cast<const int32_t *>(result.masks[i].Data()));
  238. // only reference to mask data (zero copy)
  239. cv::Mat mask(mask_h, mask_w, CV_32SC1, mask_raw_data);
  240. if ((mask_h != box_h) || (mask_w != box_w)) {
  241. cv::resize(mask, mask, cv::Size(box_w, box_h));
  242. }
  243. // use a bright color for instance mask
  244. int mc0 = 255 - c0 >= 127 ? 255 - c0 : 127;
  245. int mc1 = 255 - c1 >= 127 ? 255 - c1 : 127;
  246. int mc2 = 255 - c2 >= 127 ? 255 - c2 : 127;
  247. int32_t *mask_data = reinterpret_cast<int32_t *>(mask.data);
  248. // inplace blending (zero copy)
  249. uchar *vis_im_data = static_cast<uchar *>(vis_im.data);
  250. for (size_t i = y1; i < y2; ++i) {
  251. for (size_t j = x1; j < x2; ++j) {
  252. if (mask_data[(i - y1) * mask_w + (j - x1)] != 0) {
  253. vis_im_data[i * w * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
  254. static_cast<float>(mc0) * 0.5f +
  255. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 0]) * 0.5f);
  256. vis_im_data[i * w * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
  257. static_cast<float>(mc1) * 0.5f +
  258. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 1]) * 0.5f);
  259. vis_im_data[i * w * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
  260. static_cast<float>(mc2) * 0.5f +
  261. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 2]) * 0.5f);
  262. }
  263. }
  264. }
  265. }
  266. }
  267. return vis_im;
  268. }
  269. // Default only support visualize num_classes <= 1000
  270. // If need to visualize num_classes > 1000
  271. // Please call Visualize::GetColorMap(num_classes) first
  272. cv::Mat Visualize::VisDetection(const cv::Mat &im,
  273. const DetectionResult &result,
  274. float score_threshold, int line_size,
  275. float font_size) {
  276. if (result.boxes.empty()) {
  277. return im;
  278. }
  279. FDWARNING << "DEPRECATED: ultra_infer::vision::Visualize::VisDetection is "
  280. "deprecated, please use ultra_infer::vision:VisDetection "
  281. "function instead."
  282. << std::endl;
  283. if (result.contain_masks) {
  284. FDASSERT(result.boxes.size() == result.masks.size(),
  285. "The size of masks must be equal the size of boxes!");
  286. }
  287. auto color_map = GetColorMap();
  288. int h = im.rows;
  289. int w = im.cols;
  290. auto vis_im = im.clone();
  291. for (size_t i = 0; i < result.boxes.size(); ++i) {
  292. if (result.scores[i] < score_threshold) {
  293. continue;
  294. }
  295. int x1 = static_cast<int>(result.boxes[i][0]);
  296. int y1 = static_cast<int>(result.boxes[i][1]);
  297. int x2 = static_cast<int>(result.boxes[i][2]);
  298. int y2 = static_cast<int>(result.boxes[i][3]);
  299. int box_h = y2 - y1;
  300. int box_w = x2 - x1;
  301. int c0 = color_map[3 * result.label_ids[i] + 0];
  302. int c1 = color_map[3 * result.label_ids[i] + 1];
  303. int c2 = color_map[3 * result.label_ids[i] + 2];
  304. cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  305. std::string id = std::to_string(result.label_ids[i]);
  306. std::string score = std::to_string(result.scores[i]);
  307. if (score.size() > 4) {
  308. score = score.substr(0, 4);
  309. }
  310. std::string text = id + "," + score;
  311. int font = cv::FONT_HERSHEY_SIMPLEX;
  312. cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  313. cv::Point origin;
  314. origin.x = x1;
  315. origin.y = y1;
  316. cv::Rect rect(x1, y1, box_w, box_h);
  317. cv::rectangle(vis_im, rect, rect_color, line_size);
  318. cv::putText(vis_im, text, origin, font, font_size,
  319. cv::Scalar(255, 255, 255), 1);
  320. if (result.contain_masks) {
  321. int mask_h = static_cast<int>(result.masks[i].shape[0]);
  322. int mask_w = static_cast<int>(result.masks[i].shape[1]);
  323. // non-const pointer for cv:Mat constructor
  324. int32_t *mask_raw_data = const_cast<int32_t *>(
  325. static_cast<const int32_t *>(result.masks[i].Data()));
  326. // only reference to mask data (zero copy)
  327. cv::Mat mask(mask_h, mask_w, CV_32SC1, mask_raw_data);
  328. if ((mask_h != box_h) || (mask_w != box_w)) {
  329. cv::resize(mask, mask, cv::Size(box_w, box_h));
  330. }
  331. // use a bright color for instance mask
  332. int mc0 = 255 - c0 >= 127 ? 255 - c0 : 127;
  333. int mc1 = 255 - c1 >= 127 ? 255 - c1 : 127;
  334. int mc2 = 255 - c2 >= 127 ? 255 - c2 : 127;
  335. int32_t *mask_data = reinterpret_cast<int32_t *>(mask.data);
  336. // inplace blending (zero copy)
  337. uchar *vis_im_data = static_cast<uchar *>(vis_im.data);
  338. for (size_t i = y1; i < y2; ++i) {
  339. for (size_t j = x1; j < x2; ++j) {
  340. if (mask_data[(i - y1) * mask_w + (j - x1)] != 0) {
  341. vis_im_data[i * w * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
  342. static_cast<float>(mc0) * 0.5f +
  343. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 0]) * 0.5f);
  344. vis_im_data[i * w * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
  345. static_cast<float>(mc1) * 0.5f +
  346. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 1]) * 0.5f);
  347. vis_im_data[i * w * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
  348. static_cast<float>(mc2) * 0.5f +
  349. static_cast<float>(vis_im_data[i * w * 3 + j * 3 + 2]) * 0.5f);
  350. }
  351. }
  352. }
  353. }
  354. }
  355. return vis_im;
  356. }
  357. } // namespace vision
  358. } // namespace ultra_infer