transforms.h 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
#include <yaml-cpp/yaml.h>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <inference_engine.hpp>
  25. using namespace InferenceEngine;
  26. namespace PaddleX {
  27. // Abstraction of preprocessing opration class
  28. class Transform {
  29. public:
  30. virtual void Init(const YAML::Node& item) = 0;
  31. virtual bool Run(cv::Mat* im) = 0;
  32. };
  33. class Normalize : public Transform {
  34. public:
  35. virtual void Init(const YAML::Node& item) {
  36. mean_ = item["mean"].as<std::vector<float>>();
  37. std_ = item["std"].as<std::vector<float>>();
  38. }
  39. virtual bool Run(cv::Mat* im);
  40. private:
  41. std::vector<float> mean_;
  42. std::vector<float> std_;
  43. };
  44. class ResizeByShort : public Transform {
  45. public:
  46. virtual void Init(const YAML::Node& item) {
  47. short_size_ = item["short_size"].as<int>();
  48. if (item["max_size"].IsDefined()) {
  49. max_size_ = item["max_size"].as<int>();
  50. } else {
  51. max_size_ = -1;
  52. }
  53. };
  54. virtual bool Run(cv::Mat* im);
  55. private:
  56. float GenerateScale(const cv::Mat& im);
  57. int short_size_;
  58. int max_size_;
  59. };
  60. class CenterCrop : public Transform {
  61. public:
  62. virtual void Init(const YAML::Node& item) {
  63. if (item["crop_size"].IsScalar()) {
  64. height_ = item["crop_size"].as<int>();
  65. width_ = item["crop_size"].as<int>();
  66. } else if (item["crop_size"].IsSequence()) {
  67. std::vector<int> crop_size = item["crop_size"].as<std::vector<int>>();
  68. width_ = crop_size[0];
  69. height_ = crop_size[1];
  70. }
  71. }
  72. virtual bool Run(cv::Mat* im);
  73. private:
  74. int height_;
  75. int width_;
  76. };
  77. class Transforms {
  78. public:
  79. void Init(const YAML::Node& node, bool to_rgb = true);
  80. std::shared_ptr<Transform> CreateTransform(const std::string& name);
  81. bool Run(cv::Mat* im, Blob::Ptr blob);
  82. private:
  83. std::vector<std::shared_ptr<Transform>> transforms_;
  84. bool to_rgb_ = true;
  85. };
  86. } // namespace PaddleX