adaface_pybind.cc 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/pybind/main.h"
  15. namespace ultra_infer {
  16. void BindAdaFace(pybind11::module &m) {
  17. pybind11::class_<vision::faceid::AdaFacePreprocessor>(m,
  18. "AdaFacePreprocessor")
  19. .def(pybind11::init())
  20. .def("run",
  21. [](vision::faceid::AdaFacePreprocessor &self,
  22. std::vector<pybind11::array> &im_list) {
  23. std::vector<vision::FDMat> images;
  24. for (size_t i = 0; i < im_list.size(); ++i) {
  25. images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
  26. }
  27. std::vector<FDTensor> outputs;
  28. if (!self.Run(&images, &outputs)) {
  29. throw std::runtime_error("Failed to preprocess the input data "
  30. "in AdaFacePreprocessor.");
  31. }
  32. for (size_t i = 0; i < outputs.size(); ++i) {
  33. outputs[i].StopSharing();
  34. }
  35. return outputs;
  36. })
  37. .def_property("permute", &vision::faceid::AdaFacePreprocessor::GetPermute,
  38. &vision::faceid::AdaFacePreprocessor::SetPermute)
  39. .def_property("alpha", &vision::faceid::AdaFacePreprocessor::GetAlpha,
  40. &vision::faceid::AdaFacePreprocessor::SetAlpha)
  41. .def_property("beta", &vision::faceid::AdaFacePreprocessor::GetBeta,
  42. &vision::faceid::AdaFacePreprocessor::SetBeta)
  43. .def_property("size", &vision::faceid::AdaFacePreprocessor::GetSize,
  44. &vision::faceid::AdaFacePreprocessor::SetSize);
  45. pybind11::class_<vision::faceid::AdaFacePostprocessor>(m,
  46. "AdaFacePostprocessor")
  47. .def(pybind11::init())
  48. .def("run",
  49. [](vision::faceid::AdaFacePostprocessor &self,
  50. std::vector<FDTensor> &inputs) {
  51. std::vector<vision::FaceRecognitionResult> results;
  52. if (!self.Run(inputs, &results)) {
  53. throw std::runtime_error("Failed to postprocess the runtime "
  54. "result in AdaFacePostprocessor.");
  55. }
  56. return results;
  57. })
  58. .def("run",
  59. [](vision::faceid::AdaFacePostprocessor &self,
  60. std::vector<pybind11::array> &input_array) {
  61. std::vector<vision::FaceRecognitionResult> results;
  62. std::vector<FDTensor> inputs;
  63. PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
  64. if (!self.Run(inputs, &results)) {
  65. throw std::runtime_error("Failed to postprocess the runtime "
  66. "result in AdaFacePostprocessor.");
  67. }
  68. return results;
  69. })
  70. .def_property("l2_normalize",
  71. &vision::faceid::AdaFacePostprocessor::GetL2Normalize,
  72. &vision::faceid::AdaFacePostprocessor::SetL2Normalize);
  73. pybind11::class_<vision::faceid::AdaFace, UltraInferModel>(m, "AdaFace")
  74. .def(pybind11::init<std::string, std::string, RuntimeOption,
  75. ModelFormat>())
  76. .def("predict",
  77. [](vision::faceid::AdaFace &self, pybind11::array &data) {
  78. cv::Mat im = PyArrayToCvMat(data);
  79. vision::FaceRecognitionResult result;
  80. self.Predict(im, &result);
  81. return result;
  82. })
  83. .def("batch_predict",
  84. [](vision::faceid::AdaFace &self,
  85. std::vector<pybind11::array> &data) {
  86. std::vector<cv::Mat> images;
  87. for (size_t i = 0; i < data.size(); ++i) {
  88. images.push_back(PyArrayToCvMat(data[i]));
  89. }
  90. std::vector<vision::FaceRecognitionResult> results;
  91. self.BatchPredict(images, &results);
  92. return results;
  93. })
  94. .def_property_readonly("preprocessor",
  95. &vision::faceid::AdaFace::GetPreprocessor)
  96. .def_property_readonly("postprocessor",
  97. &vision::faceid::AdaFace::GetPostprocessor);
  98. }
  99. } // namespace ultra_infer