// swap_background_arm.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/vision/visualize/swap_background_arm.h"
#include "ultra_infer/vision/visualize/visualize.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#include "ultra_infer/utils/utils.h"
namespace ultra_infer {
namespace vision {
static constexpr int _OMP_THREADS = 2;
  23. cv::Mat SwapBackgroundNEON(const cv::Mat &im, const cv::Mat &background,
  24. const MattingResult &result,
  25. bool remove_small_connected_area) {
  26. #ifndef __ARM_NEON
  27. FDASSERT(false, "UltraInfer was not compiled with Arm NEON support!");
  28. #else
  29. FDASSERT((!im.empty()), "Image can't be empty!");
  30. FDASSERT((im.channels() == 3), "Only support 3 channels image mat!");
  31. FDASSERT((!background.empty()), "Background image can't be empty!");
  32. FDASSERT((background.channels() == 3),
  33. "Only support 3 channels background image mat!");
  34. int out_h = static_cast<int>(result.shape[0]);
  35. int out_w = static_cast<int>(result.shape[1]);
  36. int height = im.rows;
  37. int width = im.cols;
  38. int bg_height = background.rows;
  39. int bg_width = background.cols;
  40. // WARN: may change the original alpha
  41. float *alpha_ptr = const_cast<float *>(result.alpha.data());
  42. cv::Mat alpha(out_h, out_w, CV_32FC1, alpha_ptr);
  43. if (remove_small_connected_area) {
  44. alpha = Visualize::RemoveSmallConnectedArea(alpha, 0.05f);
  45. }
  46. auto vis_img = cv::Mat(height, width, CV_8UC3);
  47. cv::Mat background_ref;
  48. if ((bg_height != height) || (bg_width != width)) {
  49. cv::resize(background, background_ref, cv::Size(width, height));
  50. } else {
  51. background_ref = background; // ref only
  52. }
  53. if ((background_ref).type() != CV_8UC3) {
  54. (background_ref).convertTo((background_ref), CV_8UC3);
  55. }
  56. if ((out_h != height) || (out_w != width)) {
  57. cv::resize(alpha, alpha, cv::Size(width, height));
  58. }
  59. uint8_t *vis_data = static_cast<uint8_t *>(vis_img.data);
  60. const uint8_t *background_data =
  61. static_cast<const uint8_t *>(background_ref.data);
  62. const uint8_t *im_data = static_cast<const uint8_t *>(im.data);
  63. const float *alpha_data = reinterpret_cast<const float *>(alpha.data);
  64. const int32_t size = static_cast<int32_t>(height * width);
  65. #pragma omp parallel for proc_bind(close) num_threads(_OMP_THREADS)
  66. for (int i = 0; i < size - 7; i += 8) {
  67. uint8x8x3_t ibgrx8x3 = vld3_u8(im_data + i * 3); // 24 bytes
  68. // u8 -> u16 -> u32 -> f32
  69. uint16x8_t ibx8 = vmovl_u8(ibgrx8x3.val[0]);
  70. uint16x8_t igx8 = vmovl_u8(ibgrx8x3.val[1]);
  71. uint16x8_t irx8 = vmovl_u8(ibgrx8x3.val[2]);
  72. uint8x8x3_t bbgrx8x3 = vld3_u8(background_data + i * 3); // 24 bytes
  73. uint16x8_t bbx8 = vmovl_u8(bbgrx8x3.val[0]);
  74. uint16x8_t bgx8 = vmovl_u8(bbgrx8x3.val[1]);
  75. uint16x8_t brx8 = vmovl_u8(bbgrx8x3.val[2]);
  76. uint32x4_t hibx4 = vmovl_u16(vget_high_u16(ibx8));
  77. uint32x4_t higx4 = vmovl_u16(vget_high_u16(igx8));
  78. uint32x4_t hirx4 = vmovl_u16(vget_high_u16(irx8));
  79. uint32x4_t libx4 = vmovl_u16(vget_low_u16(ibx8));
  80. uint32x4_t ligx4 = vmovl_u16(vget_low_u16(igx8));
  81. uint32x4_t lirx4 = vmovl_u16(vget_low_u16(irx8));
  82. uint32x4_t hbbx4 = vmovl_u16(vget_high_u16(bbx8));
  83. uint32x4_t hbgx4 = vmovl_u16(vget_high_u16(bgx8));
  84. uint32x4_t hbrx4 = vmovl_u16(vget_high_u16(brx8));
  85. uint32x4_t lbbx4 = vmovl_u16(vget_low_u16(bbx8));
  86. uint32x4_t lbgx4 = vmovl_u16(vget_low_u16(bgx8));
  87. uint32x4_t lbrx4 = vmovl_u16(vget_low_u16(brx8));
  88. float32x4_t fhibx4 = vcvtq_f32_u32(hibx4);
  89. float32x4_t fhigx4 = vcvtq_f32_u32(higx4);
  90. float32x4_t fhirx4 = vcvtq_f32_u32(hirx4);
  91. float32x4_t flibx4 = vcvtq_f32_u32(libx4);
  92. float32x4_t fligx4 = vcvtq_f32_u32(ligx4);
  93. float32x4_t flirx4 = vcvtq_f32_u32(lirx4);
  94. float32x4_t fhbbx4 = vcvtq_f32_u32(hbbx4);
  95. float32x4_t fhbgx4 = vcvtq_f32_u32(hbgx4);
  96. float32x4_t fhbrx4 = vcvtq_f32_u32(hbrx4);
  97. float32x4_t flbbx4 = vcvtq_f32_u32(lbbx4);
  98. float32x4_t flbgx4 = vcvtq_f32_u32(lbgx4);
  99. float32x4_t flbrx4 = vcvtq_f32_u32(lbrx4);
  100. // alpha load from little end
  101. float32x4_t lalpx4 = vld1q_f32(alpha_data + i); // low bits
  102. float32x4_t halpx4 = vld1q_f32(alpha_data + i + 4); // high bits
  103. float32x4_t rlalpx4 = vsubq_f32(vdupq_n_f32(1.0f), lalpx4);
  104. float32x4_t rhalpx4 = vsubq_f32(vdupq_n_f32(1.0f), halpx4);
  105. // blending
  106. float32x4_t fhvbx4 =
  107. vaddq_f32(vmulq_f32(fhibx4, halpx4), vmulq_f32(fhbbx4, rhalpx4));
  108. float32x4_t fhvgx4 =
  109. vaddq_f32(vmulq_f32(fhigx4, halpx4), vmulq_f32(fhbgx4, rhalpx4));
  110. float32x4_t fhvrx4 =
  111. vaddq_f32(vmulq_f32(fhirx4, halpx4), vmulq_f32(fhbrx4, rhalpx4));
  112. float32x4_t flvbx4 =
  113. vaddq_f32(vmulq_f32(flibx4, lalpx4), vmulq_f32(flbbx4, rlalpx4));
  114. float32x4_t flvgx4 =
  115. vaddq_f32(vmulq_f32(fligx4, lalpx4), vmulq_f32(flbgx4, rlalpx4));
  116. float32x4_t flvrx4 =
  117. vaddq_f32(vmulq_f32(flirx4, lalpx4), vmulq_f32(flbrx4, rlalpx4));
  118. // f32 -> u32 -> u16 -> u8
  119. uint8x8x3_t vbgrx8x3;
  120. // combine low 64 bits and high 64 bits into one 128 neon register
  121. vbgrx8x3.val[0] = vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(flvbx4)),
  122. vmovn_u32(vcvtq_u32_f32(fhvbx4))));
  123. vbgrx8x3.val[1] = vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(flvgx4)),
  124. vmovn_u32(vcvtq_u32_f32(fhvgx4))));
  125. vbgrx8x3.val[2] = vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(flvrx4)),
  126. vmovn_u32(vcvtq_u32_f32(fhvrx4))));
  127. vst3_u8(vis_data + i * 3, vbgrx8x3);
  128. }
  129. for (int i = size - 7; i < size; i++) {
  130. float alp = alpha_data[i];
  131. for (int c = 0; c < 3; ++c) {
  132. vis_data[i * 3 + 0] = cv::saturate_cast<uchar>(
  133. static_cast<float>(im_data[i * 3 + c]) * alp +
  134. (1.0f - alp) * static_cast<float>(background_data[i * 3 + c]));
  135. }
  136. }
  137. return vis_img;
  138. #endif
  139. }
  140. cv::Mat SwapBackgroundNEON(const cv::Mat &im, const cv::Mat &background,
  141. const SegmentationResult &result,
  142. int background_label) {
  143. #ifndef __ARM_NEON
  144. FDASSERT(false, "UltraInfer was not compiled with Arm NEON support!")
  145. #else
  146. FDASSERT((!im.empty()), "Image can't be empty!");
  147. FDASSERT((im.channels() == 3), "Only support 3 channels image mat!");
  148. FDASSERT((!background.empty()), "Background image can't be empty!");
  149. FDASSERT((background.channels() == 3),
  150. "Only support 3 channels background image mat!");
  151. int out_h = static_cast<int>(result.shape[0]);
  152. int out_w = static_cast<int>(result.shape[1]);
  153. int height = im.rows;
  154. int width = im.cols;
  155. int bg_height = background.rows;
  156. int bg_width = background.cols;
  157. auto vis_img = cv::Mat(height, width, CV_8UC3);
  158. cv::Mat background_ref;
  159. if ((bg_height != height) || (bg_width != width)) {
  160. cv::resize(background, background_ref, cv::Size(width, height));
  161. } else {
  162. background_ref = background; // ref only
  163. }
  164. if ((background_ref).type() != CV_8UC3) {
  165. (background_ref).convertTo((background_ref), CV_8UC3);
  166. }
  167. uint8_t *vis_data = static_cast<uint8_t *>(vis_img.data);
  168. const uint8_t *background_data =
  169. static_cast<const uint8_t *>(background_ref.data);
  170. const uint8_t *im_data = static_cast<const uint8_t *>(im.data);
  171. const uint8_t *label_data =
  172. static_cast<const uint8_t *>(result.label_map.data());
  173. const uint8_t background_label_ = static_cast<uint8_t>(background_label);
  174. const int32_t size = static_cast<int32_t>(height * width);
  175. uint8x16_t backgroundx16 = vdupq_n_u8(background_label_);
  176. #pragma omp parallel for proc_bind(close) num_threads(_OMP_THREADS)
  177. for (int i = 0; i < size - 15; i += 16) {
  178. uint8x16x3_t ibgr16x3 = vld3q_u8(im_data + i * 3); // 48 bytes
  179. uint8x16x3_t bbgr16x3 = vld3q_u8(background_data + i * 3);
  180. uint8x16_t labelx16 = vld1q_u8(label_data + i); // 16 bytes
  181. // Set mask bit = 1 if label != background_label
  182. uint8x16_t nkeepx16 = vceqq_u8(labelx16, backgroundx16);
  183. uint8x16_t keepx16 = vmvnq_u8(nkeepx16); // keep_value = 1
  184. uint8x16x3_t vbgr16x3;
  185. vbgr16x3.val[0] = vorrq_u8(vandq_u8(ibgr16x3.val[0], keepx16),
  186. vandq_u8(bbgr16x3.val[0], nkeepx16));
  187. vbgr16x3.val[1] = vorrq_u8(vandq_u8(ibgr16x3.val[1], keepx16),
  188. vandq_u8(bbgr16x3.val[1], nkeepx16));
  189. vbgr16x3.val[2] = vorrq_u8(vandq_u8(ibgr16x3.val[2], keepx16),
  190. vandq_u8(bbgr16x3.val[2], nkeepx16));
  191. // Store the blended pixels to vis img
  192. vst3q_u8(vis_data + i * 3, vbgr16x3);
  193. }
  194. for (int i = size - 15; i < size; i++) {
  195. uint8_t label = label_data[i];
  196. if (label != background_label_) {
  197. vis_data[i * 3 + 0] = im_data[i * 3 + 0];
  198. vis_data[i * 3 + 1] = im_data[i * 3 + 1];
  199. vis_data[i * 3 + 2] = im_data[i * 3 + 2];
  200. } else {
  201. vis_data[i * 3 + 0] = background_data[i * 3 + 0];
  202. vis_data[i * 3 + 1] = background_data[i * 3 + 1];
  203. vis_data[i * 3 + 2] = background_data[i * 3 + 2];
  204. }
  205. }
  206. return vis_img;
  207. #endif
  208. }
} // namespace vision
} // namespace ultra_infer