transforms.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <iostream>
  15. #include <string>
  16. #include <vector>
  17. #include "include/paddlex/transforms.h"
  18. namespace PaddleX {
  19. std::map<std::string, int> interpolations = {{"LINEAR", cv::INTER_LINEAR},
  20. {"NEAREST", cv::INTER_NEAREST},
  21. {"AREA", cv::INTER_AREA},
  22. {"CUBIC", cv::INTER_CUBIC},
  23. {"LANCZOS4", cv::INTER_LANCZOS4}};
  24. bool Normalize::Run(cv::Mat* im, ImageBlob* data) {
  25. for (int h = 0; h < im->rows; h++) {
  26. for (int w = 0; w < im->cols; w++) {
  27. im->at<cv::Vec3f>(h, w)[0] =
  28. (im->at<cv::Vec3f>(h, w)[0] / 255.0 - mean_[0]) / std_[0];
  29. im->at<cv::Vec3f>(h, w)[1] =
  30. (im->at<cv::Vec3f>(h, w)[1] / 255.0 - mean_[1]) / std_[1];
  31. im->at<cv::Vec3f>(h, w)[2] =
  32. (im->at<cv::Vec3f>(h, w)[2] / 255.0 - mean_[2]) / std_[2];
  33. }
  34. }
  35. return true;
  36. }
  37. bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
  38. int height = static_cast<int>(im->rows);
  39. int width = static_cast<int>(im->cols);
  40. if (height < height_ || width < width_) {
  41. std::cerr << "[CenterCrop] Image size less than crop size" << std::endl;
  42. return false;
  43. }
  44. int offset_x = static_cast<int>((width - width_) / 2);
  45. int offset_y = static_cast<int>((height - height_) / 2);
  46. cv::Rect crop_roi(offset_x, offset_y, width_, height_);
  47. *im = (*im)(crop_roi);
  48. data->new_im_size_[0] = im->rows;
  49. data->new_im_size_[1] = im->cols;
  50. return true;
  51. }
  52. bool Resize::Run(cv::Mat* im, ImageBlob* data) {
  53. if (width_ <= 0 || height_ <= 0) {
  54. std::cerr << "[Resize] width and height should be greater than 0"
  55. << std::endl;
  56. return false;
  57. }
  58. if (interpolations.count(interp_) <= 0) {
  59. std::cerr << "[Resize] Invalid interpolation method: '" << interp_ << "'"
  60. << std::endl;
  61. return false;
  62. }
  63. data->im_size_before_resize_.push_back({im->rows, im->cols});
  64. data->reshape_order_.push_back("resize");
  65. cv::resize(
  66. *im, *im, cv::Size(width_, height_), 0, 0, interpolations[interp_]);
  67. data->new_im_size_[0] = im->rows;
  68. data->new_im_size_[1] = im->cols;
  69. return true;
  70. }
  71. void Transforms::Init(const YAML::Node& transforms_node, bool to_rgb) {
  72. transforms_.clear();
  73. to_rgb_ = to_rgb;
  74. for (const auto& item : transforms_node) {
  75. std::string name = item.begin()->first.as<std::string>();
  76. std::cout << "trans name: " << name << std::endl;
  77. std::shared_ptr<Transform> transform = CreateTransform(name);
  78. transform->Init(item.begin()->second);
  79. transforms_.push_back(transform);
  80. }
  81. }
  82. std::shared_ptr<Transform> Transforms::CreateTransform(
  83. const std::string& transform_name) {
  84. if (transform_name == "Normalize") {
  85. return std::make_shared<Normalize>();
  86. } else if (transform_name == "CenterCrop") {
  87. return std::make_shared<CenterCrop>();
  88. } else if (transform_name == "Resize") {
  89. return std::make_shared<Resize>();
  90. } else {
  91. std::cerr << "There's unexpected transform(name='" << transform_name
  92. << "')." << std::endl;
  93. exit(-1);
  94. }
  95. }
  96. bool Transforms::Run(cv::Mat* im, Blob::ptr data) {
  97. // 按照transforms中预处理算子顺序处理图像
  98. if (to_rgb_) {
  99. cv::cvtColor(*im, *im, cv::COLOR_BGR2RGB);
  100. }
  101. (*im).convertTo(*im, CV_32FC3);
  102. for (int i = 0; i < transforms_.size(); ++i) {
  103. if (!transforms_[i]->Run(im, data)) {
  104. std::cerr << "Apply transforms to image failed!" << std::endl;
  105. return false;
  106. }
  107. }
  108. // 将图像由NHWC转为NCHW格式
  109. // 同时转为连续的内存块存储到Blob
  110. SizeVector blobSize = data_->getTensorDesc().getDims();
  111. const size_t width = blobSize[3];
  112. const size_t height = blobSize[2];
  113. const size_t channels = blobSize[1];
  114. MemoryBlob::Ptr mblob = InferenceEngine::as<MemoryBlob>(blob);
  115. auto mblobHolder = mblob->wmap();
  116. float *blob_data = mblobHolder.as<float *>();
  117. for (size_t c = 0; c < channels; c++) {
  118. for (size_t h = 0; h < height; h++) {
  119. for (size_t w = 0; w < width; w++) {
  120. blob_data[c * width * height + h * width + w] =
  121. im.at<cv::Vec3f>(h, w)[c];
  122. }
  123. }
  124. }
  125. return true;
  126. }
  127. } // namespace PaddleX