matting.cc 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "opencv2/highgui.hpp"
  15. #include "opencv2/imgproc/imgproc.hpp"
  16. #include "ultra_infer/vision/visualize/visualize.h"
  17. namespace ultra_infer {
  18. namespace vision {
  19. cv::Mat VisMatting(const cv::Mat &im, const MattingResult &result,
  20. bool transparent_background, float transparent_threshold,
  21. bool remove_small_connected_area) {
  22. FDASSERT((!im.empty()), "im can't be empty!");
  23. FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
  24. auto vis_img = im.clone();
  25. cv::Mat transparent_vis_mat;
  26. int channel = im.channels();
  27. int out_h = static_cast<int>(result.shape[0]);
  28. int out_w = static_cast<int>(result.shape[1]);
  29. int height = im.rows;
  30. int width = im.cols;
  31. std::vector<float> alpha_copy;
  32. alpha_copy.assign(result.alpha.begin(), result.alpha.end());
  33. float *alpha_ptr = static_cast<float *>(alpha_copy.data());
  34. cv::Mat alpha(out_h, out_w, CV_32FC1, alpha_ptr);
  35. if (remove_small_connected_area) {
  36. alpha = RemoveSmallConnectedArea(alpha, 0.05f);
  37. }
  38. if ((out_h != height) || (out_w != width)) {
  39. cv::resize(alpha, alpha, cv::Size(width, height));
  40. }
  41. if ((vis_img).type() != CV_8UC3) {
  42. (vis_img).convertTo((vis_img), CV_8UC3);
  43. }
  44. if (transparent_background) {
  45. if (vis_img.channels() != 4) {
  46. cv::cvtColor(vis_img, transparent_vis_mat, cv::COLOR_BGR2BGRA);
  47. vis_img = transparent_vis_mat;
  48. channel = 4;
  49. }
  50. }
  51. uchar *vis_data = static_cast<uchar *>(vis_img.data);
  52. uchar *im_data = static_cast<uchar *>(im.data);
  53. float *alpha_data = reinterpret_cast<float *>(alpha.data);
  54. for (size_t i = 0; i < height; ++i) {
  55. for (size_t j = 0; j < width; ++j) {
  56. float alpha_val = alpha_data[i * width + j];
  57. if (transparent_background) {
  58. if (alpha_val < transparent_threshold) {
  59. vis_data[i * width * channel + j * channel + 3] =
  60. cv::saturate_cast<uchar>(0.f);
  61. } else {
  62. vis_data[i * width * channel + j * channel + 0] =
  63. cv::saturate_cast<uchar>(
  64. static_cast<float>(im_data[i * width * 3 + j * 3 + 0]));
  65. vis_data[i * width * channel + j * channel + 1] =
  66. cv::saturate_cast<uchar>(
  67. static_cast<float>(im_data[i * width * 3 + j * 3 + 1]));
  68. vis_data[i * width * channel + j * channel + 2] =
  69. cv::saturate_cast<uchar>(
  70. static_cast<float>(im_data[i * width * 3 + j * 3 + 2]));
  71. }
  72. } else {
  73. vis_data[i * width * channel + j * channel + 0] =
  74. cv::saturate_cast<uchar>(
  75. static_cast<float>(im_data[i * width * 3 + j * 3 + 0]) *
  76. alpha_val +
  77. (1.f - alpha_val) * 153.f);
  78. vis_data[i * width * channel + j * channel + 1] =
  79. cv::saturate_cast<uchar>(
  80. static_cast<float>(im_data[i * width * 3 + j * 3 + 1]) *
  81. alpha_val +
  82. (1.f - alpha_val) * 255.f);
  83. vis_data[i * width * channel + j * channel + 2] =
  84. cv::saturate_cast<uchar>(
  85. static_cast<float>(im_data[i * width * 3 + j * 3 + 2]) *
  86. alpha_val +
  87. (1.f - alpha_val) * 120.f);
  88. }
  89. }
  90. }
  91. return vis_img;
  92. }
  93. cv::Mat Visualize::VisMattingAlpha(const cv::Mat &im,
  94. const MattingResult &result,
  95. bool remove_small_connected_area) {
  96. FDWARNING << "DEPRECATED: ultra_infer::vision::Visualize::VisMattingAlpha is "
  97. "deprecated, please use ultra_infer::vision:VisMatting function "
  98. "instead."
  99. << std::endl;
  100. FDASSERT((!im.empty()), "im can't be empty!");
  101. FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
  102. auto vis_img = im.clone();
  103. int out_h = static_cast<int>(result.shape[0]);
  104. int out_w = static_cast<int>(result.shape[1]);
  105. int height = im.rows;
  106. int width = im.cols;
  107. std::vector<float> alpha_copy;
  108. alpha_copy.assign(result.alpha.begin(), result.alpha.end());
  109. float *alpha_ptr = static_cast<float *>(alpha_copy.data());
  110. cv::Mat alpha(out_h, out_w, CV_32FC1, alpha_ptr);
  111. if (remove_small_connected_area) {
  112. alpha = RemoveSmallConnectedArea(alpha, 0.05f);
  113. }
  114. if ((out_h != height) || (out_w != width)) {
  115. cv::resize(alpha, alpha, cv::Size(width, height));
  116. }
  117. if ((vis_img).type() != CV_8UC3) {
  118. (vis_img).convertTo((vis_img), CV_8UC3);
  119. }
  120. uchar *vis_data = static_cast<uchar *>(vis_img.data);
  121. uchar *im_data = static_cast<uchar *>(im.data);
  122. float *alpha_data = reinterpret_cast<float *>(alpha.data);
  123. for (size_t i = 0; i < height; ++i) {
  124. for (size_t j = 0; j < width; ++j) {
  125. float alpha_val = alpha_data[i * width + j];
  126. vis_data[i * width * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
  127. static_cast<float>(im_data[i * width * 3 + j * 3 + 0]) * alpha_val +
  128. (1.f - alpha_val) * 153.f);
  129. vis_data[i * width * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
  130. static_cast<float>(im_data[i * width * 3 + j * 3 + 1]) * alpha_val +
  131. (1.f - alpha_val) * 255.f);
  132. vis_data[i * width * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
  133. static_cast<float>(im_data[i * width * 3 + j * 3 + 2]) * alpha_val +
  134. (1.f - alpha_val) * 120.f);
  135. }
  136. }
  137. return vis_img;
  138. }
  139. } // namespace vision
  140. } // namespace ultra_infer