// seg_postprocess.cpp
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "model_deploy/ppseg/include/seg_postprocess.h"
  15. #include <time.h>
  16. namespace PaddleDeploy {
  17. bool SegPostprocess::Init(const YAML::Node& yaml_config) {
  18. if (yaml_config["version"].IsDefined() &&
  19. yaml_config["toolkit"].as<std::string>() == "PaddleX") {
  20. version_ = yaml_config["version"].as<std::string>();
  21. } else {
  22. version_ = "0.0.0";
  23. }
  24. return true;
  25. }
  26. void SegPostprocess::RestoreSegMap(const ShapeInfo& shape_info,
  27. cv::Mat* label_mat, cv::Mat* score_mat,
  28. SegResult* result) {
  29. int ori_h = shape_info.shapes[0][1];
  30. int ori_w = shape_info.shapes[0][0];
  31. int score_c = score_mat->channels();
  32. result->label_map.Resize({ori_h, ori_w});
  33. if (score_c == 1) {
  34. result->score_map.Resize({ori_h, ori_w});
  35. } else {
  36. result->score_map.Resize({ori_h, ori_w, score_c});
  37. }
  38. for (int j = shape_info.transforms.size() - 1; j > 0; --j) {
  39. std::vector<int> last_shape = shape_info.shapes[j - 1];
  40. std::vector<int> cur_shape = shape_info.shapes[j];
  41. if (shape_info.transforms[j] == "Resize" ||
  42. shape_info.transforms[j] == "ResizeByShort" ||
  43. shape_info.transforms[j] == "ResizeByLong") {
  44. if (last_shape[0] != label_mat->cols ||
  45. last_shape[1] != label_mat->rows) {
  46. cv::resize(*label_mat, *label_mat,
  47. cv::Size(last_shape[0], last_shape[1]), 0, 0,
  48. cv::INTER_NEAREST);
  49. cv::resize(*score_mat, *score_mat,
  50. cv::Size(last_shape[0], last_shape[1]), 0, 0,
  51. cv::INTER_LINEAR);
  52. }
  53. } else if (shape_info.transforms[j] == "Padding") {
  54. if (last_shape[0] < label_mat->cols || last_shape[1] < label_mat->rows) {
  55. *label_mat = (*label_mat)(cv::Rect(0, 0, last_shape[0], last_shape[1]));
  56. *score_mat = (*score_mat)(cv::Rect(0, 0, last_shape[0], last_shape[1]));
  57. }
  58. }
  59. }
  60. result->label_map.data.assign(label_mat->begin<uint8_t>(),
  61. label_mat->end<uint8_t>());
  62. result->score_map.data.assign(score_mat->begin<float>(),
  63. score_mat->end<float>());
  64. }
  65. // ppseg version >= 2.1 shape = [b, w, h]
  66. bool SegPostprocess::RunV2(const DataBlob& output,
  67. const std::vector<ShapeInfo>& shape_infos,
  68. std::vector<Result>* results, int thread_num) {
  69. int batch_size = shape_infos.size();
  70. int label_map_size = output.shape[1] * output.shape[2];
  71. const uint8_t* label_data;
  72. std::vector<uint8_t> label_vector;
  73. if (output.dtype == INT64) { // int64
  74. const int64_t* output_data =
  75. reinterpret_cast<const int64_t*>(output.data.data());
  76. std::transform(output_data, output_data + label_map_size * batch_size,
  77. std::back_inserter(label_vector),
  78. [](int64_t x) { return (uint8_t)x; });
  79. label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
  80. } else if (output.dtype == INT32) { // int32
  81. const int32_t* output_data =
  82. reinterpret_cast<const int32_t*>(output.data.data());
  83. std::transform(output_data, output_data + label_map_size * batch_size,
  84. std::back_inserter(label_vector),
  85. [](int32_t x) { return (uint8_t)x; });
  86. label_data = reinterpret_cast<const uint8_t*>(label_vector.data());
  87. } else if (output.dtype == INT8) { // uint8
  88. label_data = reinterpret_cast<const uint8_t*>(output.data.data());
  89. } else {
  90. std::cerr << "Output dtype is not support on seg posrtprocess "
  91. << output.dtype << std::endl;
  92. return false;
  93. }
  94. for (int i = 0; i < batch_size; ++i) {
  95. (*results)[i].model_type = "seg";
  96. (*results)[i].seg_result = new SegResult();
  97. const uint8_t* current_start_ptr = label_data + i * label_map_size;
  98. cv::Mat score_mat(output.shape[1], output.shape[2], CV_32FC1,
  99. cv::Scalar(1.0));
  100. cv::Mat label_mat(output.shape[1], output.shape[2], CV_8UC1,
  101. const_cast<uint8_t*>(current_start_ptr));
  102. RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
  103. (*results)[i].seg_result);
  104. }
  105. return true;
  106. }
  107. // paddlex version >= 2.0.0 shape = [b, h, w, c]
  108. bool SegPostprocess::RunXV2(const std::vector<DataBlob>& outputs,
  109. const std::vector<ShapeInfo>& shape_infos,
  110. std::vector<Result>* results, int thread_num) {
  111. int batch_size = shape_infos.size();
  112. int label_map_size = outputs[0].shape[1] * outputs[1].shape[2];
  113. std::vector<int> score_map_shape = outputs[1].shape;
  114. int score_map_size =
  115. std::accumulate(score_map_shape.begin() + 1, score_map_shape.end(), 1,
  116. std::multiplies<int>());
  117. const uint8_t* label_map_data;
  118. std::vector<uint8_t> label_map_vector;
  119. if (outputs[0].dtype == INT32) {
  120. const int32_t* output_data =
  121. reinterpret_cast<const int32_t*>(outputs[0].data.data());
  122. std::transform(output_data, output_data + label_map_size * batch_size,
  123. std::back_inserter(label_map_vector),
  124. [](int32_t x) { return (uint8_t)x; });
  125. label_map_data = reinterpret_cast<const uint8_t*>(label_map_vector.data());
  126. }
  127. const float* score_map_data =
  128. reinterpret_cast<const float*>(outputs[index].data.data());
  129. for (int i = 0; i < batch_size; ++i) {
  130. (*results)[i].model_type = "seg";
  131. (*results)[i].seg_result = new SegResult();
  132. const uint8_t* current_label_start_ptr =
  133. label_map_data + i * label_map_size;
  134. const float* current_score_start_ptr = score_map_data + i * score_map_size;
  135. cv::Mat label_mat(outputs[0].shape[1], outputs[0].shape[2], CV_8UC1,
  136. const_cast<uint8_t*>(current_label_start_ptr));
  137. cv::Mat score_mat(score_map_shape[1], score_map_shape[2], CV_32FC(n),
  138. const_cast<float*>(current_score_start_ptr));
  139. RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
  140. (*results)[i].seg_result);
  141. }
  142. return true;
  143. }
  144. bool SegPostprocess::Run(const std::vector<DataBlob>& outputs,
  145. const std::vector<ShapeInfo>& shape_infos,
  146. std::vector<Result>* results, int thread_num) {
  147. if (outputs.size() == 0) {
  148. std::cerr << "empty output on SegPostprocess" << std::endl;
  149. return true;
  150. }
  151. results->clear();
  152. int batch_size = shape_infos.size();
  153. results->resize(batch_size);
  154. // tricks for PaddleX, of which segmentation model has two outputs
  155. int index = 0;
  156. if (outputs.size() == 2) {
  157. index = 1;
  158. }
  159. std::vector<int> score_map_shape = outputs[index].shape;
  160. // paddlex version >= 2.0.0 shape[b, h, w, c]
  161. if (version_ >= "2.0.0") {
  162. return RunXV2(outputs, shape_infos, results, thread_num);
  163. }
  164. // ppseg version >= 2.1 shape = [b, h, w]
  165. if (score_map_shape.size() == 3) {
  166. return RunV2(outputs[index], shape_infos, results, thread_num);
  167. }
  168. int score_map_size =
  169. std::accumulate(score_map_shape.begin() + 1, score_map_shape.end(), 1,
  170. std::multiplies<int>());
  171. const float* score_map_data =
  172. reinterpret_cast<const float*>(outputs[index].data.data());
  173. int num_map_pixels = score_map_shape[2] * score_map_shape[3];
  174. for (int i = 0; i < batch_size; ++i) {
  175. (*results)[i].model_type = "seg";
  176. (*results)[i].seg_result = new SegResult();
  177. const float* current_start_ptr = score_map_data + i * score_map_size;
  178. cv::Mat ori_score_mat(score_map_shape[1],
  179. score_map_shape[2] * score_map_shape[3], CV_32FC1,
  180. const_cast<float*>(current_start_ptr));
  181. ori_score_mat = ori_score_mat.t();
  182. cv::Mat score_mat(score_map_shape[2], score_map_shape[3], CV_32FC1);
  183. cv::Mat label_mat(score_map_shape[2], score_map_shape[3], CV_8UC1);
  184. for (int j = 0; j < ori_score_mat.rows; ++j) {
  185. double max_value;
  186. cv::Point max_id;
  187. minMaxLoc(ori_score_mat.row(j), 0, &max_value, 0, &max_id);
  188. score_mat.at<float>(j) = max_value;
  189. label_mat.at<uchar>(j) = max_id.x;
  190. }
  191. RestoreSegMap(shape_infos[i], &label_mat, &score_mat,
  192. (*results)[i].seg_result);
  193. }
  194. return true;
  195. }
  196. } // namespace PaddleDeploy