// transform.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/vision/common/processors/transform.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
  15. namespace ultra_infer {
  16. namespace vision {
  17. void FuseNormalizeCast(std::vector<std::shared_ptr<Processor>> *processors) {
  18. // Fuse Normalize and Cast<Float>
  19. int cast_index = -1;
  20. for (size_t i = 0; i < processors->size(); ++i) {
  21. if ((*processors)[i]->Name() == "Cast") {
  22. if (i == 0) {
  23. continue;
  24. }
  25. if ((*processors)[i - 1]->Name() != "Normalize" &&
  26. (*processors)[i - 1]->Name() != "NormalizeAndPermute") {
  27. continue;
  28. }
  29. cast_index = i;
  30. }
  31. }
  32. if (cast_index < 0) {
  33. return;
  34. }
  35. if (dynamic_cast<Cast *>((*processors)[cast_index].get())->GetDtype() !=
  36. "float") {
  37. return;
  38. }
  39. processors->erase(processors->begin() + cast_index);
  40. FDINFO << (*processors)[cast_index - 1]->Name() << " and Cast are fused to "
  41. << (*processors)[cast_index - 1]->Name()
  42. << " in preprocessing pipeline." << std::endl;
  43. }
  44. void FuseNormalizeHWC2CHW(std::vector<std::shared_ptr<Processor>> *processors) {
  45. // Fuse Normalize and HWC2CHW to NormalizeAndPermute
  46. int hwc2chw_index = -1;
  47. for (size_t i = 0; i < processors->size(); ++i) {
  48. if ((*processors)[i]->Name() == "HWC2CHW") {
  49. if (i == 0) {
  50. continue;
  51. }
  52. if ((*processors)[i - 1]->Name() != "Normalize") {
  53. continue;
  54. }
  55. hwc2chw_index = i;
  56. }
  57. }
  58. if (hwc2chw_index < 0) {
  59. return;
  60. }
  61. // Get alpha and beta of Normalize
  62. std::vector<float> alpha =
  63. dynamic_cast<Normalize *>((*processors)[hwc2chw_index - 1].get())
  64. ->GetAlpha();
  65. std::vector<float> beta =
  66. dynamic_cast<Normalize *>((*processors)[hwc2chw_index - 1].get())
  67. ->GetBeta();
  68. // Delete Normalize and HWC2CHW
  69. processors->erase(processors->begin() + hwc2chw_index);
  70. processors->erase(processors->begin() + hwc2chw_index - 1);
  71. // Add NormalizeAndPermute
  72. std::vector<float> mean({0.0, 0.0, 0.0});
  73. std::vector<float> std({1.0, 1.0, 1.0});
  74. processors->push_back(std::make_shared<NormalizeAndPermute>(mean, std));
  75. // Set alpha and beta
  76. auto processor = dynamic_cast<NormalizeAndPermute *>(
  77. (*processors)[hwc2chw_index - 1].get());
  78. processor->SetAlpha(alpha);
  79. processor->SetBeta(beta);
  80. FDINFO << "Normalize and HWC2CHW are fused to NormalizeAndPermute "
  81. " in preprocessing pipeline."
  82. << std::endl;
  83. }
  84. void FuseNormalizeColorConvert(
  85. std::vector<std::shared_ptr<Processor>> *processors) {
  86. // Fuse Normalize and BGR2RGB/RGB2BGR
  87. int normalize_index = -1;
  88. int color_convert_index = -1;
  89. // If these middle processors are after BGR2RGB/RGB2BGR and before Normalize,
  90. // we can still fuse Normalize and BGR2RGB/RGB2BGR
  91. static std::unordered_set<std::string> middle_processors(
  92. {"Resize", "ResizeByShort", "ResizeByLong", "Crop", "CenterCrop",
  93. "LimitByStride", "LimitShort", "Pad", "PadToSize", "StridePad",
  94. "WarpAffine"});
  95. for (size_t i = 0; i < processors->size(); ++i) {
  96. if ((*processors)[i]->Name() == "BGR2RGB" ||
  97. (*processors)[i]->Name() == "RGB2BGR") {
  98. color_convert_index = i;
  99. for (size_t j = color_convert_index + 1; j < processors->size(); ++j) {
  100. if ((*processors)[j]->Name() == "Normalize" ||
  101. (*processors)[j]->Name() == "NormalizeAndPermute") {
  102. normalize_index = j;
  103. break;
  104. }
  105. }
  106. if (normalize_index < 0) {
  107. return;
  108. }
  109. for (size_t j = color_convert_index + 1; j < normalize_index; ++j) {
  110. if (middle_processors.count((*processors)[j]->Name())) {
  111. continue;
  112. }
  113. return;
  114. }
  115. }
  116. }
  117. if (color_convert_index < 0) {
  118. return;
  119. }
  120. // Delete Color Space Convert
  121. std::string color_processor_name = (*processors)[color_convert_index]->Name();
  122. processors->erase(processors->begin() + color_convert_index);
  123. // Toggle the swap_rb option of the Normalize processor
  124. std::string normalize_processor_name =
  125. (*processors)[normalize_index - 1]->Name();
  126. bool swap_rb;
  127. if (normalize_processor_name == "Normalize") {
  128. auto processor =
  129. dynamic_cast<Normalize *>((*processors)[normalize_index - 1].get());
  130. swap_rb = processor->GetSwapRB();
  131. processor->SetSwapRB(!swap_rb);
  132. } else if (normalize_processor_name == "NormalizeAndPermute") {
  133. auto processor = dynamic_cast<NormalizeAndPermute *>(
  134. (*processors)[normalize_index - 1].get());
  135. swap_rb = processor->GetSwapRB();
  136. processor->SetSwapRB(!swap_rb);
  137. } else {
  138. FDASSERT(false, "Something wrong in FuseNormalizeColorConvert().");
  139. }
  140. FDINFO << color_processor_name << " and " << normalize_processor_name
  141. << " are fused to " << normalize_processor_name
  142. << " with swap_rb=" << !swap_rb << std::endl;
  143. }
  144. void FuseTransforms(std::vector<std::shared_ptr<Processor>> *processors) {
  145. FuseNormalizeCast(processors);
  146. FuseNormalizeHWC2CHW(processors);
  147. FuseNormalizeColorConvert(processors);
  148. }
  149. } // namespace vision
  150. } // namespace ultra_infer