animegan_pybind.cc 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/pybind/main.h"
  15. namespace ultra_infer {
  16. void BindAnimeGAN(pybind11::module &m) {
  17. pybind11::class_<vision::generation::AnimeGAN, UltraInferModel>(m, "AnimeGAN")
  18. .def(pybind11::init<std::string, std::string, RuntimeOption,
  19. ModelFormat>())
  20. .def("predict",
  21. [](vision::generation::AnimeGAN &self, pybind11::array &data) {
  22. auto mat = PyArrayToCvMat(data);
  23. cv::Mat res;
  24. self.Predict(mat, &res);
  25. auto ret = pybind11::array_t<unsigned char>(
  26. {res.rows, res.cols, res.channels()}, res.data);
  27. return ret;
  28. })
  29. .def("batch_predict",
  30. [](vision::generation::AnimeGAN &self,
  31. std::vector<pybind11::array> &data) {
  32. std::vector<cv::Mat> images;
  33. for (size_t i = 0; i < data.size(); ++i) {
  34. images.push_back(PyArrayToCvMat(data[i]));
  35. }
  36. std::vector<cv::Mat> results;
  37. self.BatchPredict(images, &results);
  38. std::vector<pybind11::array_t<unsigned char>> ret;
  39. for (size_t i = 0; i < results.size(); ++i) {
  40. ret.push_back(pybind11::array_t<unsigned char>(
  41. {results[i].rows, results[i].cols, results[i].channels()},
  42. results[i].data));
  43. }
  44. return ret;
  45. })
  46. .def_property_readonly("preprocessor",
  47. &vision::generation::AnimeGAN::GetPreprocessor)
  48. .def_property_readonly("postprocessor",
  49. &vision::generation::AnimeGAN::GetPostprocessor);
  50. pybind11::class_<vision::generation::AnimeGANPreprocessor>(
  51. m, "AnimeGANPreprocessor")
  52. .def(pybind11::init<>())
  53. .def("run", [](vision::generation::AnimeGANPreprocessor &self,
  54. std::vector<pybind11::array> &im_list) {
  55. std::vector<vision::FDMat> images;
  56. for (size_t i = 0; i < im_list.size(); ++i) {
  57. images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
  58. }
  59. std::vector<FDTensor> outputs;
  60. if (!self.Run(images, &outputs)) {
  61. throw std::runtime_error(
  62. "Failed to preprocess the input data in PaddleClasPreprocessor.");
  63. }
  64. for (size_t i = 0; i < outputs.size(); ++i) {
  65. outputs[i].StopSharing();
  66. }
  67. return outputs;
  68. });
  69. pybind11::class_<vision::generation::AnimeGANPostprocessor>(
  70. m, "AnimeGANPostprocessor")
  71. .def(pybind11::init<>())
  72. .def("run", [](vision::generation::AnimeGANPostprocessor &self,
  73. std::vector<FDTensor> &inputs) {
  74. std::vector<cv::Mat> results;
  75. if (!self.Run(inputs, &results)) {
  76. throw std::runtime_error("Failed to postprocess the runtime result "
  77. "in YOLOv5Postprocessor.");
  78. }
  79. return results;
  80. });
  81. }
  82. } // namespace ultra_infer