postprocessor.cc 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/generation/contrib/postprocessor.h"
  15. namespace ultra_infer {
  16. namespace vision {
  17. namespace generation {
  18. bool AnimeGANPostprocessor::Run(std::vector<FDTensor> &infer_results,
  19. std::vector<cv::Mat> *results) {
  20. // 1. Reverse normalization
  21. // 2. RGB2BGR
  22. FDTensor &output_tensor = infer_results.at(0);
  23. std::vector<int64_t> shape = output_tensor.Shape(); // n, h, w, c
  24. int size = shape[1] * shape[2] * shape[3];
  25. results->resize(shape[0]);
  26. float *infer_result_data = reinterpret_cast<float *>(output_tensor.Data());
  27. for (size_t i = 0; i < results->size(); ++i) {
  28. Mat result_mat = Mat::Create(shape[1], shape[2], 3, FDDataType::FP32,
  29. infer_result_data + i * size);
  30. std::vector<float> mean{127.5f, 127.5f, 127.5f};
  31. std::vector<float> std{127.5f, 127.5f, 127.5f};
  32. Convert::Run(&result_mat, mean, std);
  33. // tmp data type is float[0-1.0],convert to uint type
  34. auto temp = result_mat.GetOpenCVMat();
  35. cv::Mat res = cv::Mat::zeros(temp->size(), CV_8UC3);
  36. temp->convertTo(res, CV_8UC3, 1);
  37. Mat fd_image = WrapMat(res);
  38. BGR2RGB::Run(&fd_image);
  39. res = *(fd_image.GetOpenCVMat());
  40. res.copyTo(results->at(i));
  41. }
  42. return true;
  43. }
  44. } // namespace generation
  45. } // namespace vision
  46. } // namespace ultra_infer