preprocessor.cc 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/classification/ppcls/preprocessor.h"
  15. #include "yaml-cpp/yaml.h"
  16. namespace ultra_infer {
  17. namespace vision {
  18. namespace classification {
  19. PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string &config_file) {
  20. this->config_file_ = config_file;
  21. FDASSERT(BuildPreprocessPipelineFromConfig(),
  22. "Failed to create PaddleClasPreprocessor.");
  23. initialized_ = true;
  24. }
  25. bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {
  26. processors_.clear();
  27. YAML::Node cfg;
  28. try {
  29. cfg = YAML::LoadFile(config_file_);
  30. } catch (YAML::BadFile &e) {
  31. FDERROR << "Failed to load yaml file " << config_file_
  32. << ", maybe you should check this file." << std::endl;
  33. return false;
  34. }
  35. auto preprocess_cfg = cfg["PreProcess"]["transform_ops"];
  36. processors_.push_back(std::make_shared<BGR2RGB>());
  37. for (const auto &op : preprocess_cfg) {
  38. FDASSERT(op.IsMap(),
  39. "Require the transform information in yaml be Map type.");
  40. auto op_name = op.begin()->first.as<std::string>();
  41. if (op_name == "ResizeImage") {
  42. if (op.begin()->second["resize_short"]) {
  43. int target_size = op.begin()->second["resize_short"].as<int>();
  44. bool use_scale = false;
  45. int interp = 1;
  46. processors_.push_back(
  47. std::make_shared<ResizeByShort>(target_size, 1, use_scale));
  48. } else if (op.begin()->second["size"]) {
  49. int width = 0;
  50. int height = 0;
  51. if (op.begin()->second["size"].IsScalar()) {
  52. auto size = op.begin()->second["size"].as<int>();
  53. width = size;
  54. height = size;
  55. } else {
  56. auto size = op.begin()->second["size"].as<std::vector<int>>();
  57. width = size[0];
  58. height = size[1];
  59. }
  60. processors_.push_back(
  61. std::make_shared<Resize>(width, height, -1.0, -1.0, 1, false));
  62. } else {
  63. FDERROR << "Invalid params for ResizeImage for both 'size' and "
  64. "'resize_short' are None"
  65. << std::endl;
  66. }
  67. } else if (op_name == "CropImage") {
  68. int width = op.begin()->second["size"].as<int>();
  69. int height = op.begin()->second["size"].as<int>();
  70. processors_.push_back(std::make_shared<CenterCrop>(width, height));
  71. } else if (op_name == "NormalizeImage") {
  72. if (!disable_normalize_) {
  73. auto mean = op.begin()->second["mean"].as<std::vector<float>>();
  74. auto std = op.begin()->second["std"].as<std::vector<float>>();
  75. const auto &scale_origin = op.begin()->second["scale"];
  76. float scale;
  77. if (scale_origin.as<std::string>() == "1/255") {
  78. scale = 1.0f / 255.0f;
  79. } else {
  80. scale = scale_origin.as<float>();
  81. }
  82. processors_.push_back(std::make_shared<Normalize>(
  83. mean, std, true, std::vector<float>(mean.size(), 0.0f),
  84. std::vector<float>(mean.size(), 1.0f / scale)));
  85. }
  86. } else if (op_name == "ToCHWImage") {
  87. if (!disable_permute_) {
  88. processors_.push_back(std::make_shared<HWC2CHW>());
  89. }
  90. } else {
  91. FDERROR << "Unexcepted preprocess operator: " << op_name << "."
  92. << std::endl;
  93. return false;
  94. }
  95. }
  96. // Fusion will improve performance
  97. FuseTransforms(&processors_);
  98. return true;
  99. }
  100. void PaddleClasPreprocessor::DisableNormalize() {
  101. this->disable_normalize_ = true;
  102. // the DisableNormalize function will be invalid if the configuration file is
  103. // loaded during preprocessing
  104. if (!BuildPreprocessPipelineFromConfig()) {
  105. FDERROR << "Failed to build preprocess pipeline from configuration file."
  106. << std::endl;
  107. }
  108. }
  109. void PaddleClasPreprocessor::DisablePermute() {
  110. this->disable_permute_ = true;
  111. // the DisablePermute function will be invalid if the configuration file is
  112. // loaded during preprocessing
  113. if (!BuildPreprocessPipelineFromConfig()) {
  114. FDERROR << "Failed to build preprocess pipeline from configuration file."
  115. << std::endl;
  116. }
  117. }
  118. bool PaddleClasPreprocessor::Apply(FDMatBatch *image_batch,
  119. std::vector<FDTensor> *outputs) {
  120. if (!initialized_) {
  121. FDERROR << "The preprocessor is not initialized." << std::endl;
  122. return false;
  123. }
  124. for (size_t j = 0; j < processors_.size(); ++j) {
  125. image_batch->proc_lib = proc_lib_;
  126. if (initial_resize_on_cpu_ && j == 0 &&
  127. processors_[j]->Name().find("Resize") == 0) {
  128. image_batch->proc_lib = ProcLib::OPENCV;
  129. }
  130. if (!(*(processors_[j].get()))(image_batch)) {
  131. FDERROR << "Failed to process image in " << processors_[j]->Name() << "."
  132. << std::endl;
  133. return false;
  134. }
  135. }
  136. outputs->resize(1);
  137. FDTensor *tensor = image_batch->Tensor();
  138. (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
  139. tensor->Data(), tensor->device,
  140. tensor->device_id);
  141. return true;
  142. }
  143. } // namespace classification
  144. } // namespace vision
  145. } // namespace ultra_infer