det_postprocess.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "model_deploy/ppdet/include/det_postprocess.h"
  15. namespace PaddleDeploy {
  16. bool DetPostprocess::Init(const YAML::Node& yaml_config) {
  17. labels_.clear();
  18. for (auto item : yaml_config["labels"]) {
  19. std::string label = item.as<std::string>();
  20. labels_.push_back(label);
  21. }
  22. version_ = yaml_config["version"].as<std::string>();
  23. return true;
  24. }
  25. bool DetPostprocess::ProcessBbox(const std::vector<DataBlob>& outputs,
  26. const std::vector<ShapeInfo>& shape_infos,
  27. std::vector<Result>* results, int thread_num) {
  28. const float* data = reinterpret_cast<const float*>(outputs[0].data.data());
  29. std::vector<int> num_bboxes_each_sample;
  30. if (outputs[0].lod.empty()) {
  31. for (auto i = 0; i < shape_infos.size(); ++i) {
  32. num_bboxes_each_sample.push_back(
  33. outputs[0].shape[0]/static_cast<int>(shape_infos.size()));
  34. }
  35. } else {
  36. for (auto i = 0; i < outputs[0].lod[0].size() - 1; ++i) {
  37. int num = outputs[0].lod[0][i + 1] - outputs[0].lod[0][i];
  38. num_bboxes_each_sample.push_back(num);
  39. }
  40. }
  41. int idx = 0;
  42. for (auto i = 0; i < num_bboxes_each_sample.size(); ++i) {
  43. (*results)[i].model_type = "det";
  44. (*results)[i].det_result = new DetResult();
  45. for (auto j = 0; j < num_bboxes_each_sample[i]; ++j) {
  46. Box box;
  47. box.category_id = static_cast<int>(round(data[idx * 6]));
  48. if (box.category_id < 0) {
  49. std::cerr << "Compute category id is less than 0"
  50. << "(Maybe no object detected)" << std::endl;
  51. return true;
  52. }
  53. if (box.category_id >= labels_.size()) {
  54. std::cerr << "Compute category id is greater than labels "
  55. << "in your config file" << std::endl;
  56. std::cerr << "Compute Category ID: " << box.category_id
  57. << ", but length of labels is " << labels_.size()
  58. << std::endl;
  59. return false;
  60. }
  61. box.category = labels_[box.category_id];
  62. box.score = data[idx * 6 + 1];
  63. // TODO(jiangjiajun): only for RCNN and YOLO
  64. // lack of process for SSD and Face
  65. float xmin = data[idx * 6 + 2];
  66. float ymin = data[idx * 6 + 3];
  67. float xmax = data[idx * 6 + 4];
  68. float ymax = data[idx * 6 + 5];
  69. box.coordinate = {xmin, ymin, xmax - xmin, ymax - ymin};
  70. (*results)[i].det_result->boxes.push_back(std::move(box));
  71. idx += 1;
  72. }
  73. }
  74. return true;
  75. }
  76. bool DetPostprocess::ProcessMask(DataBlob* mask_blob,
  77. const std::vector<ShapeInfo>& shape_infos,
  78. std::vector<Result>* results,
  79. float threshold) {
  80. std::vector<int> output_mask_shape = mask_blob->shape;
  81. float *mask_data = reinterpret_cast<float*>(mask_blob->data.data());
  82. int mask_pixels = output_mask_shape[2] * output_mask_shape[3];
  83. int classes = output_mask_shape[1];
  84. for (auto i = 0; i < results->size(); ++i) {
  85. (*results)[i].det_result->mask_resolution = output_mask_shape[2];
  86. for (auto j = 0; j < (*results)[i].det_result->boxes.size(); ++j) {
  87. Box *box = &(*results)[i].det_result->boxes[j];
  88. auto begin_mask_data = mask_data + box->category_id * mask_pixels;
  89. cv::Mat bin_mask(output_mask_shape[2],
  90. output_mask_shape[3],
  91. CV_32FC1,
  92. begin_mask_data);
  93. // expand box
  94. cv::Scalar value = cv::Scalar(0.0);
  95. cv::copyMakeBorder(bin_mask, bin_mask,
  96. 1, 1, 1, 1,
  97. cv::BORDER_CONSTANT,
  98. value = value);
  99. int max_w = shape_infos[i].shapes[0][0];
  100. int max_h = shape_infos[i].shapes[0][1];
  101. double scale = (output_mask_shape[2] + 2.0) / output_mask_shape[2];
  102. double w_half = static_cast<double>(box->coordinate[2]) * 0.5;
  103. double h_half = static_cast<double>(box->coordinate[3]) * 0.5;
  104. double x_c = static_cast<double>(box->coordinate[0]) + w_half;
  105. double y_c = static_cast<double>(box->coordinate[1]) + h_half;
  106. w_half *= scale;
  107. h_half *= scale;
  108. int x_min = static_cast<int>(x_c - w_half);
  109. int x_max = static_cast<int>(x_c + w_half);
  110. int y_min = static_cast<int>(y_c - h_half);
  111. int y_max = static_cast<int>(y_c + h_half);
  112. cv::resize(bin_mask, bin_mask,
  113. cv::Size(std::max(x_max - x_min + 1, 1),
  114. std::max(y_max - y_min + 1, 1)));
  115. cv::threshold(bin_mask, bin_mask, threshold, 1, cv::THRESH_BINARY);
  116. bin_mask.convertTo(bin_mask, CV_8UC1);
  117. int x0 = std::min(std::max(x_min, 0), max_w);
  118. int x1 = std::min(std::max(x_max + 1, 0), max_w);
  119. int y0 = std::min(std::max(y_min, 0), max_h);
  120. int y1 = std::min(std::max(y_max + 1, 0), max_h);
  121. cv::Mat mask_mat = bin_mask(cv::Range(y0 - y_min, y1 - y_min),
  122. cv::Range(x0 - x_min, x1 - x_min));
  123. // expand image
  124. cv::copyMakeBorder(mask_mat, mask_mat,
  125. max_h - y1,
  126. y0,
  127. x0,
  128. max_w - x1,
  129. cv::BORDER_CONSTANT,
  130. value = value);
  131. box->mask.Clear();
  132. box->mask.shape = {max_h, max_w};
  133. if (mask_mat.isContinuous()) {
  134. box->mask.data.assign(mask_mat.datastart, mask_mat.dataend);
  135. } else {
  136. for (auto i = 0; i < mask_mat.rows; ++i) {
  137. box->mask.data.insert(box->mask.data.end(),
  138. mask_mat.ptr<uint8_t>(i),
  139. mask_mat.ptr<uint8_t>(i) + mask_mat.cols);
  140. }
  141. }
  142. mask_data += classes * mask_pixels;
  143. }
  144. }
  145. return true;
  146. }
  147. bool DetPostprocess::ProcessMaskV2(DataBlob* mask_blob,
  148. const std::vector<ShapeInfo>& shape_infos,
  149. std::vector<Result>* results) {
  150. std::vector<int> output_mask_shape = mask_blob->shape;
  151. float *mask_data = reinterpret_cast<float*>(mask_blob->data.data());
  152. int mask_pixels = output_mask_shape[1] * output_mask_shape[2];
  153. for (auto i = 0; i < results->size(); ++i) {
  154. for (auto j = 0; j < (*results)[i].det_result->boxes.size(); ++j) {
  155. Box *box = &(*results)[i].det_result->boxes[j];
  156. auto begin_mask = mask_data + j * mask_pixels;
  157. cv::Mat bin_mask(output_mask_shape[1],
  158. output_mask_shape[2],
  159. CV_32SC1,
  160. begin_mask);
  161. bin_mask.convertTo(bin_mask, CV_8UC1);
  162. box->mask.Clear();
  163. box->mask.shape = {static_cast<int>(output_mask_shape[1]),
  164. static_cast<int>(output_mask_shape[2])};
  165. if (bin_mask.isContinuous()) {
  166. box->mask.data.assign(bin_mask.datastart, bin_mask.dataend);
  167. } else {
  168. for (auto i = 0; i < bin_mask.rows; ++i) {
  169. box->mask.data.insert(box->mask.data.end(),
  170. bin_mask.ptr<uint8_t>(i),
  171. bin_mask.ptr<uint8_t>(i) + bin_mask.cols);
  172. }
  173. }
  174. }
  175. }
  176. return true;
  177. }
  178. bool DetPostprocess::Run(const std::vector<DataBlob>& outputs,
  179. const std::vector<ShapeInfo>& shape_infos,
  180. std::vector<Result>* results, int thread_num) {
  181. results->clear();
  182. if (outputs.size() == 0) {
  183. std::cerr << "empty output on DetPostprocess" << std::endl;
  184. return true;
  185. }
  186. results->resize(shape_infos.size());
  187. if (!ProcessBbox(outputs, shape_infos, results, thread_num)) {
  188. std::cerr << "Error happend while process bboxes" << std::endl;
  189. return false;
  190. }
  191. if (version_ < "2.0" && outputs.size() == 2) {
  192. DataBlob mask_blob = outputs[1];
  193. if (!ProcessMask(&mask_blob, shape_infos, results)) {
  194. std::cerr << "Error happend while process masks" << std::endl;
  195. return false;
  196. }
  197. } else if (version_ >= "2.0" && outputs.size() == 3) {
  198. DataBlob mask_blob = outputs[2];
  199. if (!ProcessMaskV2(&mask_blob, shape_infos, results)) {
  200. std::cerr << "Error happend while process masks" << std::endl;
  201. return false;
  202. }
  203. }
  204. return true;
  205. }
  206. } // namespace PaddleDeploy