base_preprocess.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "model_deploy/common/include/base_preprocess.h"

#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
  16. namespace PaddleDeploy {
  17. bool BasePreprocess::BuildTransform(const YAML::Node& yaml_config) {
  18. transforms_.clear();
  19. YAML::Node transforms_node = yaml_config["transforms"];
  20. for (YAML::const_iterator it = transforms_node.begin();
  21. it != transforms_node.end(); ++it) {
  22. std::string name = it->first.as<std::string>();
  23. std::shared_ptr<Transform> transform = CreateTransform(name);
  24. if (!transform) {
  25. std::cerr << "Failed to create " << name << " on Preprocess" << std::endl;
  26. return false;
  27. }
  28. transform->Init(it->second);
  29. transforms_.push_back(transform);
  30. }
  31. return true;
  32. }
  33. bool BasePreprocess::ShapeInfer(const std::vector<cv::Mat>& imgs,
  34. std::vector<ShapeInfo>* shape_infos,
  35. int thread_num) {
  36. int batch_size = imgs.size();
  37. thread_num = std::min(thread_num, batch_size);
  38. shape_infos->resize(batch_size);
  39. std::vector<int> success(batch_size, 1);
  40. #pragma omp parallel for num_threads(thread_num)
  41. for (auto i = 0; i < batch_size; ++i) {
  42. int h = imgs[i].rows;
  43. int w = imgs[i].cols;
  44. (*shape_infos)[i].Insert("Origin", w, h);
  45. for (auto j = 0; j < transforms_.size(); ++j) {
  46. std::vector<int> out_shape;
  47. if (!transforms_[j]->ShapeInfer((*shape_infos)[i].shapes[j],
  48. &out_shape)) {
  49. std::cerr << "Run transforms ShapeInfer failed!" << std::endl;
  50. success[i] = 0;
  51. continue;
  52. }
  53. (*shape_infos)[i].Insert(transforms_[j]->Name(), out_shape[0],
  54. out_shape[1]);
  55. }
  56. }
  57. if (std::accumulate(success.begin(), success.end(), 0) < batch_size) {
  58. return false;
  59. }
  60. // get max shape
  61. int max_w = 0;
  62. int max_h = 0;
  63. for (auto i = 0; i < shape_infos->size(); ++i) {
  64. if ((*shape_infos)[i].shapes.back()[0] > max_w) {
  65. max_w = (*shape_infos)[i].shapes[transforms_.size()][0];
  66. }
  67. if ((*shape_infos)[i].shapes.back()[1] > max_h) {
  68. max_h = (*shape_infos)[i].shapes[transforms_.size()][1];
  69. }
  70. }
  71. for (auto i = 0; i < shape_infos->size(); ++i) {
  72. (*shape_infos)[i].Insert("Padding", max_w, max_h);
  73. }
  74. return true;
  75. }
  76. bool BasePreprocess::PreprocessImages(const std::vector<ShapeInfo>& shape_infos,
  77. std::vector<cv::Mat>* imgs,
  78. int thread_num) {
  79. int batch_size = imgs->size();
  80. thread_num = std::min(thread_num, batch_size);
  81. int max_w = shape_infos[0].shapes.back()[0];
  82. int max_h = shape_infos[0].shapes.back()[1];
  83. std::vector<int> success(batch_size, 1);
  84. #pragma omp parallel for num_threads(thread_num)
  85. for (auto i = 0; i < batch_size; ++i) {
  86. bool to_chw = false;
  87. for (auto j = 0; j < transforms_.size(); ++j) {
  88. // Permute will put to the last step to apply
  89. if (transforms_[j]->Name() == "Permute") {
  90. to_chw = true;
  91. continue;
  92. }
  93. if (!transforms_[j]->Run(&(*imgs)[i])) {
  94. std::cerr << "Run transforms to image failed!" << std::endl;
  95. success[i] = 0;
  96. continue;
  97. }
  98. }
  99. if (!batch_padding_.Run(&(*imgs)[i], max_w, max_h)) {
  100. std::cerr << "Run BatchPadding to image failed!" << std::endl;
  101. success[i] = 0;
  102. }
  103. // apply permute hwc->chw
  104. if (to_chw) {
  105. if (!permute_.Run(&(*imgs)[i])) {
  106. std::cerr << "Run Permute to image failed!" << std::endl;
  107. success[i] == 0;
  108. }
  109. }
  110. }
  111. if (std::accumulate(success.begin(), success.end(), 0) < batch_size) {
  112. return false;
  113. }
  114. return true;
  115. }
  116. std::shared_ptr<Transform> BasePreprocess::CreateTransform(
  117. const std::string& transform_name) {
  118. if (transform_name == "Normalize") {
  119. return std::make_shared<Normalize>();
  120. } else if (transform_name == "ResizeByShort") {
  121. return std::make_shared<ResizeByShort>();
  122. } else if (transform_name == "ResizeByLong") {
  123. return std::make_shared<ResizeByLong>();
  124. } else if (transform_name == "CenterCrop") {
  125. return std::make_shared<CenterCrop>();
  126. } else if (transform_name == "Permute") {
  127. return std::make_shared<Permute>();
  128. } else if (transform_name == "Resize") {
  129. return std::make_shared<Resize>();
  130. } else if (transform_name == "Padding") {
  131. return std::make_shared<Padding>();
  132. } else if (transform_name == "Clip") {
  133. return std::make_shared<Clip>();
  134. } else if (transform_name == "RGB2BGR") {
  135. return std::make_shared<RGB2BGR>();
  136. } else if (transform_name == "BGR2RGB") {
  137. return std::make_shared<BGR2RGB>();
  138. } else if (transform_name == "Convert") {
  139. return std::make_shared<Convert>();
  140. } else if (transform_name == "OcrResize") {
  141. return std::make_shared<OcrResize>();
  142. } else if (transform_name == "OcrTrtResize") {
  143. return std::make_shared<OcrTrtResize>();
  144. } else {
  145. std::cerr << "There's unexpected transform(name='" << transform_name
  146. << "')." << std::endl;
  147. return nullptr;
  148. }
  149. }
  150. } // namespace PaddleDeploy