segmentation_arm.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ultra_infer/vision/visualize/segmentation_arm.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif

namespace ultra_infer {
namespace vision {

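// Number of OpenMP threads used by the NEON visualization loops below.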
static constexpr int _OMP_THREADS = 2;

static inline void QuantizeBlendingWeight8(float weight,
                                           uint8_t *old_multi_factor,
                                           uint8_t *new_multi_factor) {
  // Quantize the weight to boost blending performance.
  // if 0.0 < w <= 1/8, w ~ 1/8 = 1/(2^3): shift right 3, new mul 1, old mul 7
  // if 1/8 < w <= 2/8, w ~ 2/8 = 2/(2^3): shift right 3, new mul 2, old mul 6
  // if 2/8 < w <= 3/8, w ~ 3/8 = 3/(2^3): shift right 3, new mul 3, old mul 5
  // if 3/8 < w <= 4/8, w ~ 4/8 = 4/(2^3): shift right 3, new mul 4, old mul 4
  // The shift factor is always 3; only the mul factor differs.
  // Moving 7 bits to the right tends to result in a zero value,
  // so we choose to shift 3 bits to get an approximation.
  uint8_t weight_quantize = static_cast<uint8_t>(weight * 8.0f);
  *new_multi_factor = weight_quantize;
  *old_multi_factor = (8 - weight_quantize);
}
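
// Overlay the segmentation mask in `result` on top of the BGR image `im`
// using Arm NEON intrinsics. The pseudo color of each class is derived from
// the label id itself (B = label << 7, G = label << 4, R = label << 3).
// Illustrative usage sketch (not part of the original file; it assumes a BGR
// cv::Mat `im` and a SegmentationResult `result` whose shape matches it):
//   cv::Mat vis = VisSegmentationNEON(im, result, /*weight=*/0.5f,
//                                     /*quantize_weight=*/true);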
cv::Mat VisSegmentationNEON(const cv::Mat &im, const SegmentationResult &result,
                            float weight, bool quantize_weight) {
#ifndef __ARM_NEON
  FDASSERT(false, "UltraInfer was not compiled with Arm NEON support!")
#else
  int64_t height = result.shape[0];
  int64_t width = result.shape[1];
  auto vis_img = cv::Mat(height, width, CV_8UC3);

  int32_t size = static_cast<int32_t>(height * width);
  uint8_t *vis_ptr = static_cast<uint8_t *>(vis_img.data);
  const uint8_t *label_ptr =
      static_cast<const uint8_t *>(result.label_map.data());
  const uint8_t *im_ptr = static_cast<const uint8_t *>(im.data);

  if (!quantize_weight) {
    uint8x16_t zerox16 = vdupq_n_u8(0);
#pragma omp parallel for proc_bind(close) num_threads(_OMP_THREADS)
    for (int i = 0; i < size - 15; i += 16) {
      uint8x16x3_t bgrx16x3 = vld3q_u8(im_ptr + i * 3);  // 48 bytes
      uint8x16_t labelx16 = vld1q_u8(label_ptr + i);     // 16 bytes
      uint8x16_t ibx16 = bgrx16x3.val[0];
      uint8x16_t igx16 = bgrx16x3.val[1];
      uint8x16_t irx16 = bgrx16x3.val[2];
      // e.g. 0b00000001 << 7 -> 0b10000000 (128)
      uint8x16_t mbx16 = vshlq_n_u8(labelx16, 7);
      uint8x16_t mgx16 = vshlq_n_u8(labelx16, 4);
      uint8x16_t mrx16 = vshlq_n_u8(labelx16, 3);
      uint8x16x3_t vbgrx16x3;
      // Keep the pixels of the input im if label (mask) == 0
      uint8x16_t cezx16 = vceqq_u8(labelx16, zerox16);
      vbgrx16x3.val[0] = vorrq_u8(vandq_u8(cezx16, ibx16), mbx16);
      vbgrx16x3.val[1] = vorrq_u8(vandq_u8(cezx16, igx16), mgx16);
      vbgrx16x3.val[2] = vorrq_u8(vandq_u8(cezx16, irx16), mrx16);
      vst3q_u8(vis_ptr + i * 3, vbgrx16x3);
    }
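    // Scalar tail for the last pixels that may not fill a full 16-lane
    // NEON register.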
    for (int i = size - 15; i < size; i++) {
      uint8_t label = label_ptr[i];
      vis_ptr[i * 3 + 0] = (label << 7);
      vis_ptr[i * 3 + 1] = (label << 4);
      vis_ptr[i * 3 + 2] = (label << 3);
    }
    // Blend the colors using OpenCV
    cv::addWeighted(im, 1.0 - weight, vis_img, weight, 0, vis_img);
    return vis_img;
  }

  // Quantize the weight to boost blending performance.
  // After that, we can directly use shift instructions
  // to blend the colors from the input im and the mask. Please
  // check QuantizeBlendingWeight8 for more details.
  uint8_t old_multi_factor, new_multi_factor;
  QuantizeBlendingWeight8(weight, &old_multi_factor, &new_multi_factor);
  if (new_multi_factor == 0) {
    return im;  // Only keep the original image.
  }

  if (new_multi_factor == 8) {
    // Only keep the mask, no need to blend with the original image.
#pragma omp parallel for proc_bind(close) num_threads(_OMP_THREADS)
    for (int i = 0; i < size - 15; i += 16) {
      uint8x16_t labelx16 = vld1q_u8(label_ptr + i);  // 16 bytes
      // e.g. 0b00000001 << 7 -> 0b10000000 (128)
      uint8x16_t mbx16 = vshlq_n_u8(labelx16, 7);
      uint8x16_t mgx16 = vshlq_n_u8(labelx16, 4);
      uint8x16_t mrx16 = vshlq_n_u8(labelx16, 3);
      uint8x16x3_t vbgr16x3;
      vbgr16x3.val[0] = mbx16;
      vbgr16x3.val[1] = mgx16;
      vbgr16x3.val[2] = mrx16;
      vst3q_u8(vis_ptr + i * 3, vbgr16x3);
    }
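    // Scalar tail for the last pixels that may not fill a full 16-lane
    // NEON register.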
    for (int i = size - 15; i < size; i++) {
      uint8_t label = label_ptr[i];
      vis_ptr[i * 3 + 0] = (label << 7);
      vis_ptr[i * 3 + 1] = (label << 4);
      vis_ptr[i * 3 + 2] = (label << 3);
    }
    return vis_img;
  }

  uint8x16_t zerox16 = vdupq_n_u8(0);
  uint8x16_t old_fx16 = vdupq_n_u8(old_multi_factor);
  uint8x16_t new_fx16 = vdupq_n_u8(new_multi_factor);
  // Blend the two colors together with the quantized 'weight'.
#pragma omp parallel for proc_bind(close) num_threads(_OMP_THREADS)
  for (int i = 0; i < size - 15; i += 16) {
    uint8x16x3_t bgrx16x3 = vld3q_u8(im_ptr + i * 3);  // 48 bytes
    uint8x16_t labelx16 = vld1q_u8(label_ptr + i);     // 16 bytes
    uint8x16_t ibx16 = bgrx16x3.val[0];
    uint8x16_t igx16 = bgrx16x3.val[1];
    uint8x16_t irx16 = bgrx16x3.val[2];
    // e.g. 0b00000001 << 7 -> 0b10000000 (128)
    uint8x16_t mbx16 = vshlq_n_u8(labelx16, 7);
    uint8x16_t mgx16 = vshlq_n_u8(labelx16, 4);
    uint8x16_t mrx16 = vshlq_n_u8(labelx16, 3);
    // Moving 7 bits to the right tends to result in zero,
    // so we choose to shift 3 bits to get an approximation.
    uint8x16_t ibx16_mshr = vmulq_u8(vshrq_n_u8(ibx16, 3), old_fx16);
    uint8x16_t igx16_mshr = vmulq_u8(vshrq_n_u8(igx16, 3), old_fx16);
    uint8x16_t irx16_mshr = vmulq_u8(vshrq_n_u8(irx16, 3), old_fx16);
    uint8x16_t mbx16_mshr = vmulq_u8(vshrq_n_u8(mbx16, 3), new_fx16);
    uint8x16_t mgx16_mshr = vmulq_u8(vshrq_n_u8(mgx16, 3), new_fx16);
    uint8x16_t mrx16_mshr = vmulq_u8(vshrq_n_u8(mrx16, 3), new_fx16);
    uint8x16_t qbx16 = vqaddq_u8(ibx16_mshr, mbx16_mshr);
    uint8x16_t qgx16 = vqaddq_u8(igx16_mshr, mgx16_mshr);
    uint8x16_t qrx16 = vqaddq_u8(irx16_mshr, mrx16_mshr);
    // Keep the pixels of the input im if label == 0 (i.e. mask == 0)
    uint8x16_t cezx16 = vceqq_u8(labelx16, zerox16);
    uint8x16_t abx16 = vandq_u8(cezx16, ibx16);
    uint8x16_t agx16 = vandq_u8(cezx16, igx16);
    uint8x16_t arx16 = vandq_u8(cezx16, irx16);
    uint8x16x3_t vbgr16x3;
    // Reset the q* values to 0 if label is 0; otherwise keep the blended
    // mask values.
    uint8x16_t ncezx16 = vmvnq_u8(cezx16);
    vbgr16x3.val[0] = vorrq_u8(abx16, vandq_u8(ncezx16, qbx16));
    vbgr16x3.val[1] = vorrq_u8(agx16, vandq_u8(ncezx16, qgx16));
    vbgr16x3.val[2] = vorrq_u8(arx16, vandq_u8(ncezx16, qrx16));
    // Store the blended pixels to the vis img
    vst3q_u8(vis_ptr + i * 3, vbgr16x3);
  }
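  // Scalar tail: apply the same quantized blending to the last pixels that
  // may not fill a full 16-lane NEON register.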
  for (int i = size - 15; i < size; i++) {
    uint8_t label = label_ptr[i];
    vis_ptr[i * 3 + 0] = (im_ptr[i * 3 + 0] >> 3) * old_multi_factor +
                         ((label << 7) >> 3) * new_multi_factor;
    vis_ptr[i * 3 + 1] = (im_ptr[i * 3 + 1] >> 3) * old_multi_factor +
                         ((label << 4) >> 3) * new_multi_factor;
    vis_ptr[i * 3 + 2] = (im_ptr[i * 3 + 2] >> 3) * old_multi_factor +
                         ((label << 3) >> 3) * new_multi_factor;
  }
  return vis_img;
#endif
}

}  // namespace vision
}  // namespace ultra_infer