  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "model_deploy/ppdet/include/det_preprocess.h"
  15. namespace PaddleDeploy {
  16. bool DetPreprocess::Init(const YAML::Node& yaml_config) {
  17. if (!BuildTransform(yaml_config)) return false;
  18. if (!yaml_config["model_name"].IsDefined()) {
  19. std::cerr << "Yaml file no model_name" << std::endl;
  20. return false;
  21. }
  22. version_ = yaml_config["version"].as<std::string>();
  23. model_arch_ = yaml_config["model_name"].as<std::string>();
  24. return true;
  25. }
  26. bool DetPreprocess::PrepareInputs(const std::vector<ShapeInfo>& shape_infos,
  27. std::vector<cv::Mat>* imgs,
  28. std::vector<DataBlob>* inputs,
  29. int thread_num) {
  30. inputs->clear();
  31. if (!PreprocessImages(shape_infos, imgs, thread_num = thread_num)) {
  32. std::cerr << "Error happend while execute function "
  33. << "DetPreprocess::Run" << std::endl;
  34. return false;
  35. }
  36. if (version_ >= "2.0") {
  37. return PrepareInputsForV2(*imgs, shape_infos, inputs, thread_num);
  38. }
  39. if (model_arch_.find("YOLO") != std::string::npos) {
  40. return PrepareInputsForYOLO(*imgs, shape_infos, inputs, thread_num);
  41. }
  42. if (model_arch_.find("RCNN") != std::string::npos) {
  43. return PrepareInputsForRCNN(*imgs, shape_infos, inputs, thread_num);
  44. }
  45. std::cerr << "Unsupported model type of '" << model_arch_ << "' "
  46. << std::endl;
  47. return false;
  48. }
// Builds the three input blobs expected by PaddleDetection >= 2.0
// exported models — "im_shape", "image", "scale_factor" — pushed into
// `inputs` in that order.
//
// shape_infos[i].shapes traces the image size through each transform:
// shapes[0] is the original size and shapes.back() the final size,
// stored as {w, h} (index 0 = width, 1 = height, per the reads below).
// Assumes every image in the batch ends at the same final shape.
// Always returns true.
bool DetPreprocess::PrepareInputsForV2(
    const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
    std::vector<DataBlob>* inputs, int thread_num) {
  DataBlob scale_factor("scale_factor");
  DataBlob image("image");
  DataBlob im_shape("im_shape");
  // TODO(jiangjiajun): only 3 channel supported
  int batch = imgs.size();
  // Final (post-transform) network input size, taken from sample 0.
  int w = shape_infos[0].shapes.back()[0];
  int h = shape_infos[0].shapes.back()[1];
  scale_factor.Resize({batch, 2}, FLOAT32);
  image.Resize({batch, 3, h, w}, FLOAT32);
  im_shape.Resize({batch, 2}, FLOAT32);
  int sample_shape = 3 * h * w;  // floats per sample (3-channel plane)
#pragma omp parallel for num_threads(thread_num)
  for (auto i = 0; i < batch; ++i) {
    int shapes_num = shape_infos[i].shapes.size();
    float origin_w = static_cast<float>(shape_infos[i].shapes[0][0]);
    float origin_h = static_cast<float>(shape_infos[i].shapes[0][1]);
    float resize_w = origin_w;
    float resize_h = origin_h;
    // Walk the transform history backwards to the size after the last
    // non-"Padding" transform: padding must not influence scale_factor.
    // NOTE(review): the loop stops at j == 2, so shapes[1] is never
    // inspected — presumably intentional, but confirm why index 1 is
    // excluded from the search.
    for (auto j = shapes_num - 1; j > 1; --j) {
      if (shape_infos[i].transforms[j] == "Padding") {
        continue;
      }
      resize_w = static_cast<float>(shape_infos[i].shapes[j][0]);
      resize_h = static_cast<float>(shape_infos[i].shapes[j][1]);
      break;
    }
    float scale_x = resize_w / origin_w;
    float scale_y = resize_h / origin_h;
    // Layouts the model expects: scale_factor = {scale_y, scale_x},
    // im_shape = {resized_h, resized_w}.
    float scale_factor_data[] = {scale_y, scale_x};
    float im_shape_data[] = {resize_h, resize_w};
    // DataBlob::data is addressed in bytes — hence the sizeof(float)
    // stride. Assumes imgs[i].data is a contiguous buffer of exactly
    // sample_shape floats — TODO(review): confirm the earlier transform
    // pipeline guarantees this (float CHW, no row padding).
    memcpy(image.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
           sample_shape * sizeof(float));
    memcpy(im_shape.data.data() + i * 2 * sizeof(float), im_shape_data,
           2 * sizeof(float));
    memcpy(scale_factor.data.data() + i * 2 * sizeof(float), scale_factor_data,
           2 * sizeof(float));
  }
  inputs->clear();
  inputs->push_back(std::move(im_shape));
  inputs->push_back(std::move(image));
  inputs->push_back(std::move(scale_factor));
  return true;
}
  95. bool DetPreprocess::PrepareInputsForYOLO(
  96. const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
  97. std::vector<DataBlob>* inputs, int thread_num) {
  98. DataBlob im("image");
  99. DataBlob im_size("im_size");
  100. // TODO(jiangjiajun): only 3 channel supported
  101. int batch = imgs.size();
  102. int w = shape_infos[0].shapes.back()[0];
  103. int h = shape_infos[0].shapes.back()[1];
  104. im.Resize({batch, 3, h, w}, FLOAT32);
  105. im_size.Resize({batch, 2}, INT32);
  106. int sample_shape = 3 * h * w;
  107. #pragma omp parallel for num_threads(thread_num)
  108. for (auto i = 0; i < batch; ++i) {
  109. memcpy(im.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
  110. sample_shape * sizeof(float));
  111. int data[2] = {shape_infos[i].shapes[0][1], shape_infos[i].shapes[0][0]};
  112. memcpy(im_size.data.data() + i * 2 * sizeof(int), data, 2 * sizeof(int));
  113. }
  114. inputs->clear();
  115. inputs->push_back(std::move(im));
  116. inputs->push_back(std::move(im_size));
  117. return true;
  118. }
  119. bool DetPreprocess::PrepareInputsForRCNN(
  120. const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
  121. std::vector<DataBlob>* inputs, int thread_num) {
  122. DataBlob im("image");
  123. DataBlob im_info("im_info");
  124. DataBlob im_shape("im_shape");
  125. // TODO(jiangjiajun): only 3 channel supported
  126. int batch = imgs.size();
  127. int w = shape_infos[0].shapes.back()[0];
  128. int h = shape_infos[0].shapes.back()[1];
  129. im.Resize({batch, 3, h, w}, FLOAT32);
  130. im_info.Resize({batch, 3}, FLOAT32);
  131. im_shape.Resize({batch, 3}, FLOAT32);
  132. int sample_shape = 3 * h * w;
  133. #pragma omp parallel for num_threads(thread_num)
  134. for (auto i = 0; i < batch; ++i) {
  135. int shapes_num = shape_infos[i].shapes.size();
  136. float origin_w = static_cast<float>(shape_infos[i].shapes[0][0]);
  137. float origin_h = static_cast<float>(shape_infos[i].shapes[0][1]);
  138. float resize_w = origin_w;
  139. for (auto j = shapes_num - 1; j > 1; --j) {
  140. if (shape_infos[i].transforms[j] == "Padding") {
  141. continue;
  142. }
  143. resize_w = static_cast<float>(shape_infos[i].shapes[j][0]);
  144. break;
  145. }
  146. float scale = resize_w / origin_w;
  147. float im_info_data[] = {static_cast<float>(h), static_cast<float>(w),
  148. scale};
  149. float im_shape_data[] = {origin_h, origin_w, 1.0};
  150. memcpy(im.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
  151. sample_shape * sizeof(float));
  152. memcpy(im_info.data.data() + i * 3 * sizeof(float), im_info_data,
  153. 3 * sizeof(float));
  154. memcpy(im_shape.data.data() + i * 3 * sizeof(float), im_shape_data,
  155. 3 * sizeof(float));
  156. }
  157. inputs->clear();
  158. inputs->push_back(std::move(im));
  159. inputs->push_back(std::move(im_info));
  160. inputs->push_back(std::move(im_shape));
  161. return true;
  162. }
  163. bool DetPreprocess::Run(std::vector<cv::Mat>* imgs,
  164. std::vector<DataBlob>* inputs,
  165. std::vector<ShapeInfo>* shape_infos, int thread_num) {
  166. if ((*imgs).size() == 0) {
  167. std::cerr << "empty input image on DetPreprocess" << std::endl;
  168. return true;
  169. }
  170. if (!ShapeInfer(*imgs, shape_infos, thread_num)) {
  171. std::cerr << "ShapeInfer failed while call DetPreprocess::Run" << std::endl;
  172. return false;
  173. }
  174. if (!PrepareInputs(*shape_infos, imgs, inputs, thread_num)) {
  175. std::cerr << "PrepareInputs failed while call "
  176. << "DetPreprocess::PrepareInputs" << std::endl;
  177. return false;
  178. }
  179. return true;
  180. }
  181. } // namespace PaddleDeploy