// det_preprocess.cpp
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "model_deploy/ppdet/include/det_preprocess.h"
  15. namespace PaddleDeploy {
  16. bool DetPreprocess::Init(const YAML::Node& yaml_config) {
  17. if (!BuildTransform(yaml_config)) return false;
  18. if (!yaml_config["model_name"].IsDefined()) {
  19. std::cerr << "Yaml file no model_name" << std::endl;
  20. return false;
  21. }
  22. version_ = yaml_config["version"].as<std::string>();
  23. model_arch_ = yaml_config["model_name"].as<std::string>();
  24. return true;
  25. }
  26. bool DetPreprocess::PrepareInputs(const std::vector<ShapeInfo>& shape_infos,
  27. std::vector<cv::Mat>* imgs,
  28. std::vector<DataBlob>* inputs,
  29. int thread_num) {
  30. inputs->clear();
  31. if (!PreprocessImages(shape_infos, imgs, thread_num = thread_num)) {
  32. std::cerr << "Error happend while execute function "
  33. << "DetPreprocess::Run" << std::endl;
  34. return false;
  35. }
  36. if (version_ >= "2.0") {
  37. return PrepareInputsForV2(*imgs, shape_infos, inputs, thread_num);
  38. }
  39. if (model_arch_.find("YOLO") != std::string::npos) {
  40. return PrepareInputsForYOLO(*imgs, shape_infos, inputs, thread_num);
  41. }
  42. if (model_arch_.find("RCNN") != std::string::npos) {
  43. return PrepareInputsForRCNN(*imgs, shape_infos, inputs, thread_num);
  44. }
  45. std::cerr << "Unsupported model type of '" << model_arch_ << "' "
  46. << std::endl;
  47. return false;
  48. }
// Packs the preprocessed batch into the three inputs used by PaddleDetection
// >= 2.0 exports: "im_shape", "image" and "scale_factor" (pushed in that
// order — the order is part of the contract with the exported model).
bool DetPreprocess::PrepareInputsForV2(
    const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
    std::vector<DataBlob>* inputs, int thread_num) {
  DataBlob scale_factor("scale_factor");
  DataBlob image("image");
  DataBlob im_shape("im_shape");
  // TODO(jiangjiajun): only 3 channel supported
  int batch = imgs.size();
  // shapes.back() is the final post-transform size stored as [w, h]; the
  // whole batch is assumed to share this final size — TODO confirm upstream
  // padding/resize guarantees that.
  int w = shape_infos[0].shapes.back()[0];
  int h = shape_infos[0].shapes.back()[1];
  scale_factor.Resize({batch, 2}, FLOAT32);
  image.Resize({batch, 3, h, w}, FLOAT32);
  im_shape.Resize({batch, 2}, FLOAT32);
  // Per-sample element count for the CHW float image tensor.
  int sample_shape = 3 * h * w;
#pragma omp parallel for num_threads(thread_num)
  for (auto i = 0; i < batch; ++i) {
    int shapes_num = shape_infos[i].shapes.size();
    // shapes[0] records the original image size as [w, h].
    float origin_w = static_cast<float>(shape_infos[i].shapes[0][0]);
    float origin_h = static_cast<float>(shape_infos[i].shapes[0][1]);
    float resize_w = origin_w;
    float resize_h = origin_h;
    // Walk the transform chain backwards to pick up the size produced by the
    // last Resize* step, skipping trailing Padding so padded borders do not
    // distort the scale factor.
    for (auto j = shapes_num - 1; j > 1; --j) {
      if (shape_infos[i].transforms[j] == "Padding") {
        continue;
      }
      resize_w = static_cast<float>(shape_infos[i].shapes[j][0]);
      resize_h = static_cast<float>(shape_infos[i].shapes[j][1]);
      // rfind("Resize", 0) == 0  <=>  the transform name starts with "Resize".
      if (shape_infos[i].transforms[j].rfind("Resize", 0) == 0)
        break;
    }
    float scale_x = resize_w / origin_w;
    float scale_y = resize_h / origin_h;
    // Packed as {scale_y, scale_x} and {resize_h, resize_w} respectively.
    float scale_factor_data[] = {scale_y, scale_x};
    float im_shape_data[] = {resize_h, resize_w};
    // DataBlob::data is a byte buffer, so all destination offsets are in
    // bytes (index * element_count * sizeof(float)). Each cv::Mat is assumed
    // to already hold contiguous float data of sample_shape elements —
    // TODO confirm against PreprocessImages.
    memcpy(image.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
           sample_shape * sizeof(float));
    memcpy(im_shape.data.data() + i * 2 * sizeof(float), im_shape_data,
           2 * sizeof(float));
    memcpy(scale_factor.data.data() + i * 2 * sizeof(float), scale_factor_data,
           2 * sizeof(float));
  }
  inputs->clear();
  inputs->push_back(std::move(im_shape));
  inputs->push_back(std::move(image));
  inputs->push_back(std::move(scale_factor));
  return true;
}
  96. bool DetPreprocess::PrepareInputsForYOLO(
  97. const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
  98. std::vector<DataBlob>* inputs, int thread_num) {
  99. DataBlob im("image");
  100. DataBlob im_size("im_size");
  101. // TODO(jiangjiajun): only 3 channel supported
  102. int batch = imgs.size();
  103. int w = shape_infos[0].shapes.back()[0];
  104. int h = shape_infos[0].shapes.back()[1];
  105. im.Resize({batch, 3, h, w}, FLOAT32);
  106. im_size.Resize({batch, 2}, INT32);
  107. int sample_shape = 3 * h * w;
  108. #pragma omp parallel for num_threads(thread_num)
  109. for (auto i = 0; i < batch; ++i) {
  110. memcpy(im.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
  111. sample_shape * sizeof(float));
  112. int data[2] = {shape_infos[i].shapes[0][1], shape_infos[i].shapes[0][0]};
  113. memcpy(im_size.data.data() + i * 2 * sizeof(int), data, 2 * sizeof(int));
  114. }
  115. inputs->clear();
  116. inputs->push_back(std::move(im));
  117. inputs->push_back(std::move(im_size));
  118. return true;
  119. }
  120. bool DetPreprocess::PrepareInputsForRCNN(
  121. const std::vector<cv::Mat>& imgs, const std::vector<ShapeInfo>& shape_infos,
  122. std::vector<DataBlob>* inputs, int thread_num) {
  123. DataBlob im("image");
  124. DataBlob im_info("im_info");
  125. DataBlob im_shape("im_shape");
  126. // TODO(jiangjiajun): only 3 channel supported
  127. int batch = imgs.size();
  128. int w = shape_infos[0].shapes.back()[0];
  129. int h = shape_infos[0].shapes.back()[1];
  130. im.Resize({batch, 3, h, w}, FLOAT32);
  131. im_info.Resize({batch, 3}, FLOAT32);
  132. im_shape.Resize({batch, 3}, FLOAT32);
  133. int sample_shape = 3 * h * w;
  134. #pragma omp parallel for num_threads(thread_num)
  135. for (auto i = 0; i < batch; ++i) {
  136. int shapes_num = shape_infos[i].shapes.size();
  137. float origin_w = static_cast<float>(shape_infos[i].shapes[0][0]);
  138. float origin_h = static_cast<float>(shape_infos[i].shapes[0][1]);
  139. float resize_w = origin_w;
  140. for (auto j = shapes_num - 1; j > 1; --j) {
  141. if (shape_infos[i].transforms[j] == "Padding") {
  142. continue;
  143. }
  144. resize_w = static_cast<float>(shape_infos[i].shapes[j][0]);
  145. break;
  146. }
  147. float scale = resize_w / origin_w;
  148. float im_info_data[] = {static_cast<float>(h), static_cast<float>(w),
  149. scale};
  150. float im_shape_data[] = {origin_h, origin_w, 1.0};
  151. memcpy(im.data.data() + i * sample_shape * sizeof(float), imgs[i].data,
  152. sample_shape * sizeof(float));
  153. memcpy(im_info.data.data() + i * 3 * sizeof(float), im_info_data,
  154. 3 * sizeof(float));
  155. memcpy(im_shape.data.data() + i * 3 * sizeof(float), im_shape_data,
  156. 3 * sizeof(float));
  157. }
  158. inputs->clear();
  159. inputs->push_back(std::move(im));
  160. inputs->push_back(std::move(im_info));
  161. inputs->push_back(std::move(im_shape));
  162. return true;
  163. }
  164. bool DetPreprocess::Run(std::vector<cv::Mat>* imgs,
  165. std::vector<DataBlob>* inputs,
  166. std::vector<ShapeInfo>* shape_infos, int thread_num) {
  167. if ((*imgs).size() == 0) {
  168. std::cerr << "empty input image on DetPreprocess" << std::endl;
  169. return true;
  170. }
  171. if (!ShapeInfer(*imgs, shape_infos, thread_num)) {
  172. std::cerr << "ShapeInfer failed while call DetPreprocess::Run" << std::endl;
  173. return false;
  174. }
  175. if (!PrepareInputs(*shape_infos, imgs, inputs, thread_num)) {
  176. std::cerr << "PrepareInputs failed while call "
  177. << "DetPreprocess::PrepareInputs" << std::endl;
  178. return false;
  179. }
  180. return true;
  181. }
  182. } // namespace PaddleDeploy